[autoparallel] integrate auto parallel feature with new tracer (#3408)
* [autoparallel] integrate new analyzer in module level
* unify the profiling method
* polish
* fix no codegen bug
* fix pass bug
* fix liveness test
* polish
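The bulk of the mechanical changes below follow one pattern: the profiling helpers move from the old `colossalai.fx.profiler` package to the new analyzer under `colossalai._analyzer`, sometimes aliased so existing call sites keep their names. A minimal sketch of the pattern, with module paths taken from the hunks below:

# Removed in this commit:
# from colossalai.fx.profiler.memory_utils import activation_size
# from colossalai.fx.profiler.opcount import flop_mapping

# Added; the alias keeps call sites that still say `activation_size` working.
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size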
@@ -1,3 +1,3 @@
 from .meta_registry import *
-from .metainfo import *
 from .registry import meta_register
+from .shard_metainfo import *
@@ -2,9 +2,9 @@ from typing import Callable, List, Tuple

 import torch

+from colossalai._analyzer._subclasses.flop_tensor import ewise_flop_counter as elementwise_flop_counter
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import elementwise_flop_counter

 from ..registry import meta_register
@@ -2,9 +2,9 @@ from typing import List, Tuple

 import torch

+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping

 from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
 from ..registry import meta_register
@@ -17,7 +17,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
     """Meta information generator for binary elementwise operations
     NOTE: Some of the binary elementwise operations will discard the input activation after computation, as they
     don't need those tensors for back propagation, for example, if there are two tensors being sent for `torch.add`,
-    they will be discarded right after add operation is done. We create a simple API in `MetaInfo` class to identify
+    they will be discarded right after add operation is done. We create a simple API in `ShardMetaInfo` class to identify
    this behavior, it is critical for better memory estimation.

    Returns:
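To see why the inputs to `torch.add` can be dropped, note that its backward passes the output gradient through unchanged, so neither input tensor is needed for back propagation. A standalone illustration in plain PyTorch (not part of this commit):

import torch

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
out = a + b
# d(out)/da and d(out)/db are both identity, so backward never reads a or b;
# their activations can be freed as soon as the add completes.
out.sum().backward()
print(a.grad, b.grad)    # both all-ones tensors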
@@ -2,6 +2,8 @@ from typing import Callable, Dict, List, Tuple, Union

 import torch

+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     MemoryCost,
     OperationData,
@@ -10,8 +12,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     StrategiesVector,
     TrainCycleItem,
 )
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 from colossalai.tensor.sharding_spec import ShardingSpec

 from ..registry import meta_register
@@ -110,18 +110,18 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     # calculate memory cost
     # TODO: use profiler to check conv temp memory
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(
-        activation=activation_size([input_tensor, output_tensor]),
-        parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
-        temp=0,
-        buffer=0)
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
+                                 if has_bias else compute_size_in_bytes(weight_tensor),
+                                 temp=0,
+                                 buffer=0)

-    bwd_memory_cost = MemoryCost(
-        activation=activation_size([input_tensor, weight_tensor, bias_tensor])
-        if has_bias else activation_size([input_tensor, weight_tensor]),
-        parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
-        temp=0,
-        buffer=0)
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
+                                 if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
+                                 if has_bias else compute_size_in_bytes(weight_tensor),
+                                 temp=0,
+                                 buffer=0)

     # total cost is the sum of forward and backward cost
     total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
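The memory formulas above only ever sum byte sizes over tensors. A plausible stand-in for `compute_size_in_bytes` (an assumption about its behavior, inferred from the `activation_size` alias used elsewhere in this commit):

import torch

def size_in_bytes(tensors):
    # Hypothetical helper: total storage bytes over one tensor or a list of tensors.
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    return sum(t.numel() * t.element_size() for t in tensors)

x = torch.empty(64, 128)      # fp32 input
w = torch.empty(128, 256)     # fp32 weight
print(size_in_bytes([x, w]))  # (64*128 + 128*256) * 4 = 163840 bytes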
@@ -2,9 +2,9 @@ from typing import List, Tuple

 import torch

+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping

 from ..registry import meta_register
@@ -34,11 +34,11 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
     # NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will
     # have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume
     # that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory
-    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
                                  parameter=0,
                                  temp=0,
                                  buffer=0)
-    bwd_memory_cost = MemoryCost(activation=activation_size([weight_tensor]), parameter=0, temp=0, buffer=0)
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0)

     total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)
@@ -3,6 +3,8 @@ from typing import Callable, Dict, List, Tuple, Union

 import torch

+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     MemoryCost,
     OperationData,
@@ -11,8 +13,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     StrategiesVector,
     TrainCycleItem,
 )
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 from colossalai.tensor.sharding_spec import ShardingSpec

 from ..registry import meta_register
@@ -112,14 +112,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     # NOTE: Linear don't have buffer and temp in forward and backward phase
     # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
-                                 parameter=activation_size([weight_tensor, bias_tensor]),
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                  temp=0,
                                  buffer=0)

     # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
-    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]),
-                                 parameter=activation_size([weight_tensor, bias_tensor]),
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                  temp=0,
                                  buffer=0)
@@ -148,14 +148,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     # NOTE: Linear don't have buffer and temp in forward and backward phase
     # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
-                                 parameter=activation_size(weight_tensor),
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
+                                 parameter=compute_size_in_bytes(weight_tensor),
                                  temp=0,
                                  buffer=0)

     # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
-    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor]),
-                                 parameter=activation_size(weight_tensor),
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]),
+                                 parameter=compute_size_in_bytes(weight_tensor),
                                  temp=0,
                                  buffer=0)
@@ -210,48 +210,48 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     # Check dimension
     if all(len(tensor.shape) == 1 for tensor in input_tensors):
         # Dot
-        fwd_compute_cost = flop_mapping[torch.ops.aten.dot.default](input_tensors, output_tensors)
+        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
         bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2

-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)

     elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1:
         # gemv case 1: matrix-vector multiplication
         # &
         # batched gemv case 1: batched matrix-vector multiplication

-        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
+        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)

        # combine the dimensions of output
        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
            [output_tensors[0].reshape(-1), input_tensors[1]],
            output_tensors) + \
-           flop_mapping[torch.ops.aten.mv.default](
+           flop_mapping[torch.ops.aten.matmul.default](
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
            output_tensors)

-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)

     elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2:
         # gemv case 2: vector-matrix multiplication
-        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](input_tensors, output_tensors)
+        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)

         bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
-            flop_mapping[torch.ops.aten.mv.default]([input_tensors[1], output_tensors[0]], output_tensors)
+            flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)

-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors),
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors),
                                   parameter=0,
-                                  temp=activation_size(input_tensors[1]),
+                                  temp=compute_size_in_bytes(input_tensors[1]),
                                   buffer=0)

     elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
         # batched gemv case 2: vector-batched matrix multiplication

-        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
+        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
            [output_tensors[0].reshape(-1)])
@@ -260,15 +260,15 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
            [output_tensors[0].reshape(-1), input_tensors[0]],
            output_tensors
        ) + \
-           flop_mapping[torch.ops.aten.mv.default](
+           flop_mapping[torch.ops.aten.matmul.default](
            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
            output_tensors
        )

-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors + [input_tensors[1]]))
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
                                   parameter=0,
-                                  temp=activation_size(input_tensors[1]),
+                                  temp=compute_size_in_bytes(input_tensors[1]),
                                   buffer=0)

     elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
@@ -287,8 +287,8 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
        )

-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)

     elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
         # batched gemm case 2: matrix-batched matrix multiplication
@@ -306,11 +306,12 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
        )

-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors) + activation_size(input_tensors[1]),
-                                  temp=activation_size(output_tensors))
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) +
+                                  compute_size_in_bytes(input_tensors[1]),
+                                  temp=compute_size_in_bytes(output_tensors))
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
                                   parameter=0,
-                                  temp=activation_size(input_tensors[1]) + activation_size(output_tensors))
+                                  temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors))

     elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
         # Batched matrix-batched matrix multiplication
@@ -351,8 +352,8 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
            [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
        )

-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors))
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors))
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors))

     else:
         # Case 2: batch dimensions are different
@@ -381,10 +382,10 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
        )

         fwd_mem_cost = MemoryCost(
-            activation=activation_size([output_tensors[0], extended_input_0, extended_input_1]))
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors) -
-                                  activation_size([extended_input_0, extended_input_1]),
-                                  temp=activation_size([extended_input_0, extended_input_1]))
+            activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]))
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) -
+                                  compute_size_in_bytes([extended_input_0, extended_input_1]),
+                                  temp=compute_size_in_bytes([extended_input_0, extended_input_1]))

     # compute cost
     compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
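The hunks above route every matmul variant through `flop_mapping[torch.ops.aten.matmul.default]` instead of the removed `dot`/`mv` entries, reshaping operands so one counter covers all cases. A rough model of how such a mapping dispatches (the counter body is a simplified assumption, not the analyzer's implementation):

import torch

def matmul_flops(inputs, outputs):
    # Count 2 * m * n * k flops for an (m, k) @ (k, n) product.
    a, b = inputs[0], inputs[1]
    m, k, n = a.shape[-2], a.shape[-1], b.shape[-1]
    return 2 * m * n * k

flop_mapping = {torch.ops.aten.matmul.default: matmul_flops}

a, b = torch.empty(8, 16), torch.empty(16, 32)
cost = flop_mapping[torch.ops.aten.matmul.default]([a, b], [torch.empty(8, 32)])
print(cost)    # 2 * 8 * 32 * 16 = 8192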
@@ -4,8 +4,6 @@ from typing import List, Tuple
 import torch

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping

 from ..registry import meta_register
@@ -2,6 +2,8 @@ from typing import Callable, Dict, List, Tuple, Union

 import torch

+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     MemoryCost,
     OperationData,
@@ -10,8 +12,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     StrategiesVector,
     TrainCycleItem,
 )
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 from colossalai.tensor.sharding_spec import ShardingSpec

 from ..registry import meta_register
@@ -77,17 +77,18 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
     # calculate memory cost
     # the fwd activation cost is output plus saved mean and saved inv std
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, mean_tensor, var_tensor]),
-                                 parameter=activation_size([weight_tensor, bias_tensor]),
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
+        [input_tensor, output_tensor, mean_tensor, var_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                  temp=0,
-                                 buffer=activation_size([mean_tensor, var_tensor]))
+                                 buffer=compute_size_in_bytes([mean_tensor, var_tensor]))

     # the bwd memory cost is quite tricky here, BatchNorm will remove saved mean
     # and saved inv std during backward phase
-    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor]),
-                                 parameter=activation_size([weight_tensor, bias_tensor]),
-                                 temp=activation_size([mean_tensor, var_tensor]),
-                                 buffer=activation_size([mean_tensor, var_tensor]))
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+                                 temp=compute_size_in_bytes([mean_tensor, var_tensor]),
+                                 buffer=compute_size_in_bytes([mean_tensor, var_tensor]))

     # total cost is the sum of forward and backward cost
     total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
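The bookkeeping above can be sanity-checked with plain numbers: the saved mean/inv-std bytes show up as forward `buffer`, then reappear as backward `temp` when BatchNorm drops them. A toy stand-in for `MemoryCost` (the real class lives in `colossalai.auto_parallel.tensor_shard.sharding_strategy`; the field defaults and sizes here are assumptions for illustration):

from dataclasses import dataclass

@dataclass
class MemoryCost:
    activation: int = 0
    parameter: int = 0
    temp: int = 0
    buffer: int = 0

stats = 2 * 64 * 4    # saved mean + inv std for 64 fp32 channels
fwd = MemoryCost(activation=1_048_576 + stats, parameter=512, buffer=stats)
bwd = MemoryCost(activation=1_048_576, parameter=512, temp=stats, buffer=stats)
total = MemoryCost(activation=fwd.activation + bwd.activation,
                   parameter=fwd.parameter + bwd.parameter,
                   temp=bwd.temp,
                   buffer=fwd.buffer)
print(total)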
@@ -131,15 +132,16 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem

     # memory cost
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, weight_tensor, bias_tensor]),
-                                 parameter=activation_size([weight_tensor, bias_tensor]),
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
+        [input_tensor, output_tensor, weight_tensor, bias_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                  temp=0,
-                                 buffer=activation_size([running_mean, running_var]))
+                                 buffer=compute_size_in_bytes([running_mean, running_var]))

-    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]),
-                                 parameter=activation_size([weight_tensor, bias_tensor]),
-                                 temp=activation_size([running_mean, running_var]),
-                                 buffer=activation_size([running_mean, running_var]))
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+                                 temp=compute_size_in_bytes([running_mean, running_var]),
+                                 buffer=compute_size_in_bytes([running_mean, running_var]))

     total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                             parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
@@ -2,9 +2,9 @@ from typing import List, Tuple

 import torch

+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping

 from ..registry import meta_register
@@ -52,8 +52,8 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
     compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

     # calculate memory cost
-    fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(output_tensor))
-    bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(input_tensor))
+    fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(output_tensor))
+    bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(input_tensor))

     # total cost
     total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation)
@@ -114,11 +114,11 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
     # calculate memory cost
     # NOTE: the index matrix will be discarded in backward phase
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_mem_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, index_matrix]))
+    fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix]))

     # temp memory for backward is the index matrix to be discarded
-    bwd_mem_cost = MemoryCost(activation=activation_size(input_tensor) - activation_size(index_matrix),
-                              temp=activation_size(index_matrix))
+    bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
+                              temp=compute_size_in_bytes(index_matrix))

     # total cost
     total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)
@@ -2,9 +2,9 @@ from typing import Callable, List, Tuple

 import torch

+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping

 from ..registry import meta_register
@@ -35,11 +35,11 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f

     # memory costs
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_mem_cost = MemoryCost(activation=activation_size(outputs) * 2, parameter=0, temp=0, buffer=0)
+    fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0)

-    bwd_mem_cost = MemoryCost(activation=activation_size(outputs) * bwd_mem_out_factor,
+    bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
                               parameter=0,
-                              temp=activation_size(outputs) * bwd_mem_tmp_factor,
+                              temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
                               buffer=0)

     total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
@@ -2,9 +2,9 @@ from typing import List, Tuple

 import torch

+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping

 from ..registry import meta_register
@@ -15,11 +15,11 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
 from .registry import meta_register

-__all__ = ['MetaInfo']
+__all__ = ['ShardMetaInfo']


-class MetaInfo:
-    """MetaInfo class
+class ShardMetaInfo:
+    """ShardMetaInfo class
     This class is used to store meta info based on sharding strategy and the given
     target function.
     """
@@ -46,9 +46,9 @@ class MetaInfo:
         # target function
         self._target = target

-        # compute metainfo if possible
+        # compute shard_metainfo if possible
         if self._strategy is not None and self._target is not None:
-            self.compute_metainfo()
+            self.compute_shard_metainfo()

     @property
     def strategy(self) -> ShardingStrategy:
@@ -62,13 +62,13 @@ class MetaInfo:
     def strategy(self, strategy: ShardingStrategy) -> None:
         self._strategy = strategy
         if self._strategy is not None and self._target is not None:
-            self.compute_metainfo()
+            self.compute_shard_metainfo()

     @target.setter
     def target(self, target: Callable) -> None:
         self._target = target
         if self._strategy is not None and self._target is not None:
-            self.compute_metainfo()
+            self.compute_shard_metainfo()

     def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec):
         """
@@ -93,7 +93,7 @@ class MetaInfo:

         return op_data

-    def compute_metainfo(self):
+    def compute_shard_metainfo(self):
         """
         Compute meta info based on sharding strategy and the given target function.
         """
@@ -4,7 +4,7 @@ import torch
 from torch.fx import GraphModule
 from torch.fx.node import Node

-from colossalai.auto_parallel.meta_profiler import MetaInfo
+from colossalai.auto_parallel.meta_profiler import ShardMetaInfo
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
 from colossalai.tensor.comm_spec import CommSpec
@@ -14,15 +14,15 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 shape_consistency_manager = ShapeConsistencyManager()


-def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
-                         target_sharding_spec: ShardingSpec) -> MetaInfo:
+def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
+                               target_sharding_spec: ShardingSpec) -> ShardMetaInfo:
     # get comm_action_sequence and total_cost from shape_consistency_manager
     _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
         origin_sharding_spec, target_sharding_spec)

-    meta_info = MetaInfo()
+    meta_info = ShardMetaInfo()
     # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
-    # get mem cost for MetaInfo
+    # get mem cost for ShardMetaInfo
     mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
     # extract user that has _meta_data and extract element length
     input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
@@ -36,12 +36,12 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,

     meta_info.memory_cost = mem_cost

-    # get computation cost for MetaInfo
+    # get computation cost for ShardMetaInfo
     meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
                                             total_cost['backward'] * element_length,
                                             total_cost['total'] * element_length)

-    # get tensor shape for MetaInfo
+    # get tensor shape for ShardMetaInfo
     origin_sharding_spec: ShardingSpec
     target_sharding_spec: ShardingSpec
     input_shape = origin_sharding_spec.get_sharded_shape_per_device()
@@ -54,7 +54,7 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
     return meta_info


-def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo:
+def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> ShardMetaInfo:
     """
     This method is used to construct `MetaInto` for shape consistency node
     """
@@ -65,17 +65,17 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -
     origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
         user_node_index]

-    return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
+    return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)


-def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo:
+def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> ShardMetaInfo:
     # extract node_index and op_data_name
     node_index, op_data_name = node.args[2], node.args[3]

     comm_action = comm_actions_dict[node_index][op_data_name]
     if isinstance(comm_action.comm_spec, CommSpec):
         # this case is for all_reduce, there will be no memory cost
-        meta_info = MetaInfo()
+        meta_info = ShardMetaInfo()
         meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)
         output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
         element_length = output_node._meta_data.element_size()
@@ -93,7 +93,7 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M
         # this case will be handled by shape consistency manager
         origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
             'tgt_spec']
-        meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
+        meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)

     return meta_info
@@ -105,9 +105,9 @@ def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_di
     """
     for node in gm.graph.nodes:
         if node.target == runtime_apply:
-            setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
+            setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
         elif node.target == runtime_comm_spec_apply:
-            setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
+            setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
         else:
             pass
     return gm
@@ -7,7 +7,7 @@ import torch.fx
 from torch.fx import GraphModule
 from torch.fx.node import Node

-from colossalai.auto_parallel.meta_profiler import MetaInfo
+from colossalai.auto_parallel.meta_profiler import ShardMetaInfo
 from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 from colossalai.fx._compatibility import compatibility
 from colossalai.fx.profiler import GraphInfo
@@ -96,12 +96,12 @@ class MetaInfoProp:
         """
         Handle other kind of nodes
         """
-        assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
+        assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
         graph_info = GraphInfo()
-        meta_info = node.best_metainfo
-        meta_info: MetaInfo
+        meta_info = node.best_strategy_info
+        meta_info: ShardMetaInfo

-        # set data_ptr for input_tensor in MetaInfo class
+        # set data_ptr for input_tensor in ShardMetaInfo class
         input_tensors: List[torch.Tensor] = meta_info.fwd_in
         buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
         output_tensors: List[torch.Tensor] = meta_info.fwd_out
@@ -4,7 +4,7 @@ from typing import Dict, List
 import torch
 from torch.fx.node import Node

-from colossalai.auto_parallel.meta_profiler import MetaInfo
+from colossalai._analyzer.fx.node_util import MetaInfo
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     CommAction,
     CommType,
@@ -128,9 +128,10 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                     runtime_apply,
                     args=(node, origin_dict_node, input_dict_node,
                           node_to_index_dict[node], user_node_index))
-                if 'activation_checkpoint' in user_node.meta:
-                    shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']
-
+                if hasattr(user_node.meta['info'], 'activation_checkpoint'):
+                    MetaInfo(shape_consistency_node,
+                             mod_dir=user_node.meta['info'].mod_dir,
+                             activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))
                 new_args = list(user_node.args)
                 new_kwargs = dict(user_node.kwargs)
                 # the origin node may be a positional argument or key word argument of user node
@@ -210,9 +211,10 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
             # substitute the origin node with comm_spec_apply_node
             new_kwargs[str(node)] = comm_spec_apply_node
             user.kwargs = new_kwargs

-        if 'activation_checkpoint' in node.meta:
-            comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
+        if hasattr(node.meta['info'], 'activation_checkpoint'):
+            MetaInfo(comm_spec_apply_node,
+                     mod_dir=node.meta['info'].mod_dir,
+                     activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))

     return gm
@@ -6,6 +6,7 @@ import torch
 from torch.fx import symbolic_trace
 from torch.fx.node import Node

+from colossalai._analyzer.fx.node_util import MetaInfo
 from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     CommAction,
@@ -74,9 +75,9 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
             origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
                 str(node))

-        # attach the corresponding metainfo if node has the attribute `metainfo_vector`
-        if hasattr(node, 'metainfo_vector'):
-            setattr(node, 'best_metainfo', node.metainfo_vector[strategy_index])
+        # attach the corresponding metainfo if node has the attribute `strategies_info`
+        if hasattr(node, 'strategies_info'):
+            setattr(node, 'best_strategy_info', node.strategies_info[strategy_index])

     # the dict to get input sharding specs of user node
     sharding_spec_convert_dict = {}
@@ -172,8 +173,11 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
             # It will be used to replace the original node with processing node in slice object
             node_pairs[node] = size_processing_node
             size_processing_node._meta_data = node._meta_data
-            if 'activation_checkpoint' in node.meta:
-                size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']

+            if hasattr(node.meta['info'], 'activation_checkpoint'):
+                MetaInfo(size_processing_node,
+                         mod_dir=node.meta['info'].mod_dir,
+                         activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))

             user_list = list(node.users.keys())
             for user in user_list:
@@ -6,6 +6,10 @@ import torch.nn as nn
 from torch.fx import GraphModule
 from torch.fx.graph import Graph

+from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
@@ -13,8 +17,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
 from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec
@@ -126,6 +128,7 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc


 def transform_to_sharded_model(gm: ColoGraphModule,
+                               meta_args: Dict,
                                solution: List[int],
                                device_mesh: DeviceMesh,
                                strategies_constructor: StrategiesConstructor,
@@ -142,6 +145,7 @@ def transform_to_sharded_model(gm: ColoGraphModule,
                                                                strategies_constructor,
                                                                overlap=overlap)
     gm = runtime_apply_pass(gm)
+    shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict)
     gm.recompile()
     sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
@@ -243,10 +247,13 @@ def initialize_model(model: nn.Module,
         solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
         return a series of integers, but return the best strategies.
     '''
-    tracer = ColoTracer(trace_act_ckpt=True)
+    tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)

     graph = tracer.trace(root=model, meta_args=meta_args)
+    graph.set_codegen(ActivationCheckpointCodeGen())
     gm = ColoGraphModule(model, graph, model.__class__.__name__)
+
+    shape_prop_pass(gm, *meta_args.values())
     gm.recompile()

     strategies_constructor = build_strategy_constructor(graph,
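Taken together, the `initialize_model` changes trace with the analyzer's tracer (now splitting bias additions into separate nodes), install the checkpoint-aware codegen, and run shape propagation before recompiling. A condensed sketch of that front end; the wrapper function name is ours, but every call inside it appears in the hunk above:

import torch
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer

def trace_for_auto_parallel(model: torch.nn.Module, meta_args: dict) -> ColoGraphModule:
    # Hypothetical wrapper around the tracing steps in initialize_model.
    tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)
    graph = tracer.trace(root=model, meta_args=meta_args)
    graph.set_codegen(ActivationCheckpointCodeGen())
    gm = ColoGraphModule(model, graph, model.__class__.__name__)
    shape_prop_pass(gm, *meta_args.values())    # shape inference now runs as a pass
    gm.recompile()
    return gm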
@@ -261,7 +268,9 @@ def initialize_model(model: nn.Module,
     if save_solver_solution:
         torch.save(solution, solution_path)

-    gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor, overlap)
+    gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor,
+                                                         overlap)

     model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)

     if return_solution:
@@ -2,8 +2,6 @@ from typing import Dict, List

 import torch

-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
-
 from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from .node_handler import MetaInfoModuleHandler, ModuleHandler
 from .registry import operator_registry
@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
 import torch
 from torch.fx.node import Node

-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
+from colossalai.auto_parallel.meta_profiler.shard_metainfo import ShardMetaInfo, meta_register
 from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
@@ -258,7 +258,7 @@ class MetaInfoNodeHandler(NodeHandler):
     def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
         """
         This method is inherited from NodeHandler. It will register the strategies first,
-        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
+        and rewrite the memory_cost and compute_cost of the strategy using the ShardMetaInfo class.
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
@@ -266,15 +266,15 @@ class MetaInfoNodeHandler(NodeHandler):
         # is not patched, we will use the default cost model to compute the cost.
         # TODO: patch all torch functions and modules to make it clean
         if meta_register.has(target.__class__) or meta_register.has(target):
-            metainfo_vector = []
+            strategies_info = []
             for strategy in self.strategies_vector:
-                metainfo = MetaInfo(strategy, target)
+                metainfo = ShardMetaInfo(strategy, target)
                 strategy.compute_cost = metainfo.compute_cost
                 strategy.memory_cost = metainfo.memory_cost
-                metainfo_vector.append(metainfo)
+                strategies_info.append(metainfo)

             # attach metainfos to the handler
-            setattr(self, "metainfo_vector", metainfo_vector)
+            setattr(self, "strategies_info", strategies_info)

         else:
             logger = get_dist_logger()
|
||||
def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
|
||||
"""
|
||||
This method is inherited from NodeHandler. It will register the strategies first,
|
||||
and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
|
||||
and rewrite the memory_cost and compute_cost of the strategy using the ShardMetaInfo class.
|
||||
"""
|
||||
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
|
||||
target = self.get_target_function()
|
||||
@@ -321,15 +321,15 @@ class MetaInfoModuleHandler(ModuleHandler):
         # is not patched, we will use the default cost model to compute the cost.
         # TODO: patch all torch functions and modules to make it clean
         if meta_register.has(target.__class__) or meta_register.has(target):
-            metainfo_vector = []
+            strategies_info = []
             for strategy in self.strategies_vector:
-                metainfo = MetaInfo(strategy, target)
+                metainfo = ShardMetaInfo(strategy, target)
                 strategy.compute_cost = metainfo.compute_cost
                 strategy.memory_cost = metainfo.memory_cost
-                metainfo_vector.append(metainfo)
+                strategies_info.append(metainfo)

             # attach metainfos to the handler
-            setattr(self, "metainfo_vector", metainfo_vector)
+            setattr(self, "strategies_info", strategies_info)

         else:
             logger = get_dist_logger()
@@ -137,9 +137,9 @@ class StrategiesConstructor:
                                          shard_option=self.solver_options.shard_option,
                                          solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()
-                # attach metainfo_vector to node
-                if hasattr(handler, 'metainfo_vector'):
-                    setattr(node, 'metainfo_vector', handler.metainfo_vector)
+                # attach strategies_info to node
+                if hasattr(handler, 'strategies_info'):
+                    setattr(node, 'strategies_info', handler.strategies_info)

             # call_function node
             elif node.op == 'call_function':
@@ -150,9 +150,9 @@ class StrategiesConstructor:
                                          shard_option=self.solver_options.shard_option,
                                          solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()
-                # attach metainfo_vector to node
-                if hasattr(handler, 'metainfo_vector'):
-                    setattr(node, 'metainfo_vector', handler.metainfo_vector)
+                # attach strategies_info to node
+                if hasattr(handler, 'strategies_info'):
+                    setattr(node, 'strategies_info', handler.strategies_info)

             # call_method node
             elif node.op == 'call_method':
@@ -163,9 +163,9 @@ class StrategiesConstructor:
                                          shard_option=self.solver_options.shard_option,
                                          solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()
-                # attach metainfo_vector to node
-                if hasattr(handler, 'metainfo_vector'):
-                    setattr(node, 'metainfo_vector', handler.metainfo_vector)
+                # attach strategies_info to node
+                if hasattr(handler, 'strategies_info'):
+                    setattr(node, 'strategies_info', handler.strategies_info)

             # output node
             elif node.op == 'output':