Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 20:40:34 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -3,8 +3,6 @@ import operator

import torch
import torch.nn as nn

from ..tensor_shard.constants import *

# list of inplace module
INPLACE_MODULE = [nn.ReLU]

@@ -25,28 +25,32 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0

    def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
        input_tensor = next(
            filter(
                lambda x:
                (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim',
                args)).data
                lambda x: (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM)
                and x.name != "softmax_dim",
                args,
            )
        ).data
        output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
        is_inplace = 1 if kwargs.get('inplace', False) else 0
        is_inplace = 1 if kwargs.get("inplace", False) else 0

        flop_counter = elementwise_flop_counter(1, 0)
        # calculate compute cost
        fwd_compute_cost = flop_counter([input_tensor], [output_tensor])
        bwd_compute_cost = flop_counter([output_tensor], [input_tensor])

        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)
        compute_cost = TrainCycleItem(
            fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
        )

        # calculate memory cost
        # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
        # NOTE: if in_place is True, we will not create a new tensor in forward
        fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace),
                                     parameter=0,
                                     temp=0,
                                     buffer=activation_size(input_tensor) * buffer_mem_scale)
        fwd_memory_cost = MemoryCost(
            activation=activation_size(input_tensor) * (2 - is_inplace),
            parameter=0,
            temp=0,
            buffer=activation_size(input_tensor) * buffer_mem_scale,
        )

        # temp_mem_scale is for situation like softmax backward
        # the buffer will be removed during backward phase

@@ -54,20 +58,23 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0
            activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale,
            parameter=0,
            temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale,
            buffer=0)
            buffer=0,
        )

        # total cost is the sum of forward and backward cost
        total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                                parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
                                temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
                                buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
        total_cost = MemoryCost(
            activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
            temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
            buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
        )

        memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

        # store fwd_in, fwd_buffer, fwd_out
        fwd_in = []
        fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
        fwd_out = [torch.zeros_like(output_tensor, device='meta')]
        fwd_buffer = [torch.zeros_like(output_tensor, device="meta")]
        fwd_out = [torch.zeros_like(output_tensor, device="meta")]

        return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
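Note: the `(2 - is_inplace)` factor in the forward memory cost above encodes whether the forward pass materializes a new output: a non-inplace op keeps both input and output alive (twice the input size), while an inplace op reuses the input's storage. A minimal sketch of that accounting, using a simplified stand-in for the project's byte-counting helper (the tensor shape here is illustrative):

import torch

def activation_size(t: torch.Tensor) -> int:
    # simplified stand-in: bytes occupied by a single tensor
    return t.numel() * t.element_size()

x = torch.zeros(1024, 1024, device="meta")  # 4 MiB as fp32
for inplace in (False, True):
    is_inplace = 1 if inplace else 0
    # non-inplace: input + fresh output stay alive after forward (2x);
    # inplace: the output reuses the input buffer (1x)
    fwd_activation = activation_size(x) * (2 - is_inplace)
    print(inplace, fwd_activation)  # 8388608, then 4194304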
@@ -6,10 +6,10 @@ from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem

from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
from ..constants import BCAST_FUNC_OP
from ..registry import meta_register

__all__ = ['binary_elementwise_meta_info']
__all__ = ["binary_elementwise_meta_info"]


@meta_register.register(BCAST_FUNC_OP)

@@ -61,6 +61,6 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
    # store fwd_in, fwd_buffer, fwd_out
    fwd_in = []
    fwd_buffer = []
    fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
    fwd_out = [torch.zeros_like(output_op_data.data, device="meta")]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@@ -1,22 +1,14 @@
from typing import Callable, Dict, List, Tuple, Union
from typing import List, Tuple

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    MemoryCost,
    OperationData,
    OperationDataType,
    ShardingStrategy,
    StrategiesVector,
    TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem

from ..registry import meta_register

__all__ = ['convnd_meta_info']
__all__ = ["convnd_meta_info"]


@meta_register.register(torch.nn.Conv1d)

@@ -103,35 +95,47 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L

    # calculate compute cost
    fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,))
    bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \
        flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
    bwd_compute_cost = (
        flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor))
        if has_bias
        else flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
    )
    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

    # calculate memory cost
    # TODO: use profiler to check conv temp memory
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
                                 if has_bias else compute_size_in_bytes(weight_tensor),
                                 temp=0,
                                 buffer=0)
    fwd_memory_cost = MemoryCost(
        activation=compute_size_in_bytes([input_tensor, output_tensor]),
        parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
        if has_bias
        else compute_size_in_bytes(weight_tensor),
        temp=0,
        buffer=0,
    )

    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
                                 if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
                                 if has_bias else compute_size_in_bytes(weight_tensor),
                                 temp=0,
                                 buffer=0)
    bwd_memory_cost = MemoryCost(
        activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
        if has_bias
        else compute_size_in_bytes([input_tensor, weight_tensor]),
        parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
        if has_bias
        else compute_size_in_bytes(weight_tensor),
        temp=0,
        buffer=0,
    )

    # total cost is the sum of forward and backward cost
    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
    total_cost = MemoryCost(
        activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
        parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
    )

    memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

    # store fwd_in, fwd_buffer, fwd_out
    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
    fwd_in = [torch.zeros_like(input_tensor, device="meta")]
    fwd_buffer = []
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
    fwd_out = [torch.zeros_like(output_tensor, device="meta")]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
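Note: `convnd_meta_info` charges parameter memory for the weight plus the bias only when the module has a bias, in both directions. A rough sketch of that byte count for a small Conv2d, assuming the byte counter sums numel times element size over the tensors it receives (a simplified stand-in for colossalai's helper):

import torch

def compute_size_in_bytes(tensors) -> int:
    # simplified stand-in for colossalai's helper
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    return sum(t.numel() * t.element_size() for t in tensors)

conv = torch.nn.Conv2d(3, 16, kernel_size=3, bias=True)
weight_tensor, bias_tensor = conv.weight, conv.bias
has_bias = bias_tensor is not None

parameter_bytes = (
    compute_size_in_bytes([weight_tensor, bias_tensor])
    if has_bias
    else compute_size_in_bytes(weight_tensor)
)
print(parameter_bytes)  # (16*3*3*3 + 16) * 4 = 1792 bytes in fp32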
@@ -24,8 +24,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem

    # compute cost
    fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor])
    bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor],
                                                                                     [weight_tensor])
    bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default](
        [output_tensor, weight_tensor], [weight_tensor]
    )

    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

@@ -34,10 +35,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
    # NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will
    # have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume
    # that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory
    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
                                 parameter=0,
                                 temp=0,
                                 buffer=0)
    fwd_memory_cost = MemoryCost(
        activation=compute_size_in_bytes([input_tensor, output_tensor]), parameter=0, temp=0, buffer=0
    )
    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0)

    total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)
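Note: the backward memory charged here is just the size of the weight tensor, because the dense embedding backward materializes a gradient with the weight's shape. A quick size sketch for a BERT-base-like table (the shape is chosen for illustration only):

import torch

vocab, dim = 30522, 768
weight_tensor = torch.zeros(vocab, dim, device="meta")

# backward produces a dense gradient the size of the weight matrix
grad_bytes = weight_tensor.numel() * weight_tensor.element_size()
print(grad_bytes)  # 93,763,584 bytes (~94 MB) in fp32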
@@ -1,23 +1,15 @@
from functools import reduce
from typing import Callable, Dict, List, Tuple, Union
from typing import List, Tuple

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    MemoryCost,
    OperationData,
    OperationDataType,
    ShardingStrategy,
    StrategiesVector,
    TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem

from ..registry import meta_register

__all__ = ['linear_meta_info', 'matmul_meta_info']
__all__ = ["linear_meta_info", "matmul_meta_info"]


@meta_register.register(torch.nn.functional.linear)

@@ -100,32 +92,43 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L

    # calculate compute cost
    fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
        [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
    bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
        flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \
        flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
    compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                  bwd=bwd_compute_cost,
                                  total=fwd_compute_cost + bwd_compute_cost)
        [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
    )
    bwd_compute_cost = (
        flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,))
        + flop_mapping[torch.ops.aten.mm.default](
            [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
        )
        + flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
    )
    compute_cost = TrainCycleItem(
        fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
    )

    # calculate memory cost
    # NOTE: Linear don't have buffer and temp in forward and backward phase
    # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                 temp=0,
                                 buffer=0)
    fwd_memory_cost = MemoryCost(
        activation=compute_size_in_bytes([input_tensor, output_tensor]),
        parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
        temp=0,
        buffer=0,
    )

    # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                 temp=0,
                                 buffer=0)
    bwd_memory_cost = MemoryCost(
        activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
        parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
        temp=0,
        buffer=0,
    )

    # total cost is to sum the forward and backward cost
    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
    total_cost = MemoryCost(
        activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
        parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
    )

    memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

@@ -136,39 +139,49 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L

    # calculate compute cost
    fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
        [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
    bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
        flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))
        [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
    )
    bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
        [output_tensor, weight_tensor], (input_tensor,)
    ) + flop_mapping[torch.ops.aten.mm.default](
        [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
    )

    compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                  bwd=bwd_compute_cost,
                                  total=fwd_compute_cost + bwd_compute_cost)
    compute_cost = TrainCycleItem(
        fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
    )

    # calculate memory cost
    # NOTE: Linear don't have buffer and temp in forward and backward phase
    # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
                                 parameter=compute_size_in_bytes(weight_tensor),
                                 temp=0,
                                 buffer=0)
    fwd_memory_cost = MemoryCost(
        activation=compute_size_in_bytes([input_tensor, output_tensor]),
        parameter=compute_size_in_bytes(weight_tensor),
        temp=0,
        buffer=0,
    )

    # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]),
                                 parameter=compute_size_in_bytes(weight_tensor),
                                 temp=0,
                                 buffer=0)
    bwd_memory_cost = MemoryCost(
        activation=compute_size_in_bytes([input_tensor, weight_tensor]),
        parameter=compute_size_in_bytes(weight_tensor),
        temp=0,
        buffer=0,
    )

    # total cost is to sum the forward and backward cost
    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
    total_cost = MemoryCost(
        activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
        parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
    )

    memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

    # store fwd_in, fwd_buffer, fwd_out
    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
    fwd_in = [torch.zeros_like(input_tensor, device="meta")]
    fwd_buffer = []
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
    fwd_out = [torch.zeros_like(output_tensor, device="meta")]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out

@@ -222,15 +235,16 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
        # batched gemv case 1: batched matrix-vector multiplication

        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors
        )

        # combine the dimensions of output
        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
            [output_tensors[0].reshape(-1), input_tensors[1]],
            output_tensors) + \
            flop_mapping[torch.ops.aten.matmul.default](
                [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
                output_tensors)
            [output_tensors[0].reshape(-1), input_tensors[1]], output_tensors
        ) + flop_mapping[torch.ops.aten.matmul.default](
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
            output_tensors,
        )

        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)

@@ -239,86 +253,104 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
        # gemv case 2: vector-matrix multiplication
        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)

        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
            flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
            [output_tensors[0], input_tensors[0]], output_tensors
        ) + flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)

        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors),
                                  parameter=0,
                                  temp=compute_size_in_bytes(input_tensors[1]),
                                  buffer=0)
        bwd_mem_cost = MemoryCost(
            activation=compute_size_in_bytes(input_tensors),
            parameter=0,
            temp=compute_size_in_bytes(input_tensors[1]),
            buffer=0,
        )

    elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
        # batched gemv case 2: vector-batched matrix multiplication

        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
            [output_tensors[0].reshape(-1)])
            [output_tensors[0].reshape(-1)],
        )

        # combine the dimensions of output
        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
            [output_tensors[0].reshape(-1), input_tensors[0]],
            output_tensors
        ) + \
            flop_mapping[torch.ops.aten.matmul.default](
                [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
                output_tensors
            )
            [output_tensors[0].reshape(-1), input_tensors[0]], output_tensors
        ) + flop_mapping[torch.ops.aten.matmul.default](
            [
                input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1),
                output_tensors[0].reshape(-1),
            ],
            output_tensors,
        )

        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
                                  parameter=0,
                                  temp=compute_size_in_bytes(input_tensors[1]),
                                  buffer=0)
        bwd_mem_cost = MemoryCost(
            activation=compute_size_in_bytes(input_tensors[0]),
            parameter=0,
            temp=compute_size_in_bytes(input_tensors[1]),
            buffer=0,
        )

    elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
        # gemm & batched gemm case 1: batched matrix-matrix multiplication

        fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
            [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])
            [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
        )

        bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
            [input_tensors[1]]
        ) + \
            flop_mapping[torch.ops.aten.mm.default](
                [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
                [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
            )
            [
                input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1),
                output_tensors[0].reshape(-1, output_tensors[0].shape[-1]),
            ],
            [input_tensors[1]],
        ) + flop_mapping[torch.ops.aten.mm.default](
            [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])],
        )

        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)

    elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
        # batched gemm case 2: matrix-batched matrix multiplication
        fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([
            input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose(
                0, 1)
        ], [output_tensors[0].transpose(-2, -1)])
        fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
            [
                input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
                input_tensors[0].transpose(0, 1),
            ],
            [output_tensors[0].transpose(-2, -1)],
        )

        bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
            [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
            [input_tensors[0]]
        ) + \
            flop_mapping[torch.ops.aten.mm.default](
                [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
                [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
            )
            [
                output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1),
                input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
            ],
            [input_tensors[0]],
        ) + flop_mapping[torch.ops.aten.mm.default](
            [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
        )

        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) +
                                  compute_size_in_bytes(input_tensors[1]),
                                  temp=compute_size_in_bytes(output_tensors))
        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
                                  parameter=0,
                                  temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors))
        fwd_mem_cost = MemoryCost(
            activation=compute_size_in_bytes(output_tensors) + compute_size_in_bytes(input_tensors[1]),
            temp=compute_size_in_bytes(output_tensors),
        )
        bwd_mem_cost = MemoryCost(
            activation=compute_size_in_bytes(input_tensors[0]),
            parameter=0,
            temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors),
        )

    elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
        # Batched matrix-batched matrix multiplication
        # Fetch shape of the two inputs and see if the batch dimensions are the same
        _is_batch_dims_same = True
        if len(input_tensors[0].shape) == len(input_tensors[1].shape):
            for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
            for shape_0, shape_1 in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
                if shape_0 != shape_1:
                    _is_batch_dims_same = False
                    break

@@ -337,20 +369,28 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
            # Case 1: batch dimensions are the same

            # Forward compute cost: C = A * B
            fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([
                input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape(
                    -1, input_dim_10, input_dim_11)
            ], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
            fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
                [
                    input_tensors[0].reshape(-1, input_dim_00, input_dim_01),
                    input_tensors[1].reshape(-1, input_dim_10, input_dim_11),
                ],
                [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
            )

            # Backward compute cost: dB = A^T * dC, dA = dC * B^T
            bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
                [input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
                [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]
            ) + \
                flop_mapping[torch.ops.aten.bmm.default](
                    [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
                    [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
                )
                [
                    input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00),
                    output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
                ],
                [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)],
            ) + flop_mapping[torch.ops.aten.bmm.default](
                [
                    output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
                    input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10),
                ],
                [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)],
            )

            fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
            bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors))

@@ -358,43 +398,46 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
        else:
            # Case 2: batch dimensions are different
            batch_dims = output_tensors[0].shape[:-2]
            extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
                                          input_dim_00,
                                          input_dim_01,
                                          device="meta")
            extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
                                          input_dim_10,
                                          input_dim_11,
                                          device="meta")
            extended_input_0 = torch.rand(
                reduce(lambda x, y: x * y, batch_dims), input_dim_00, input_dim_01, device="meta"
            )
            extended_input_1 = torch.rand(
                reduce(lambda x, y: x * y, batch_dims), input_dim_10, input_dim_11, device="meta"
            )

            # Forward compute cost: C = A * B
            fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
                [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
                [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]
            )

            # Backward compute cost: dB = A^T * dC, dA = dC * B^T
            bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
                [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
                [extended_input_1]
            ) + \
                flop_mapping[torch.ops.aten.bmm.default](
                    [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
                    [extended_input_0]
                )
                [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
                [extended_input_1],
            ) + flop_mapping[torch.ops.aten.bmm.default](
                [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
                [extended_input_0],
            )

            fwd_mem_cost = MemoryCost(
                activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]))
            bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) -
                                      compute_size_in_bytes([extended_input_0, extended_input_1]),
                                      temp=compute_size_in_bytes([extended_input_0, extended_input_1]))
                activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])
            )
            bwd_mem_cost = MemoryCost(
                activation=compute_size_in_bytes(input_tensors)
                - compute_size_in_bytes([extended_input_0, extended_input_1]),
                temp=compute_size_in_bytes([extended_input_0, extended_input_1]),
            )

    # compute cost
    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

    # memory cost
    total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
                            parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
                            temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
                            buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
    total_cost = MemoryCost(
        activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
        parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
        temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
        buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
    )

    memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
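Note: the case split above mirrors how the FLOP count changes with operand ranks. For the plain GEMM path (both operands 2-D), the forward is one m-by-k times k-by-n matmul and the backward is two matmuls of the same volume (dA = dC @ B^T and dB = A^T @ dC), so backward costs roughly twice the forward. A sketch under the usual 2*m*k*n multiply-add convention (the actual numbers come from `flop_mapping`):

m, k, n = 128, 256, 512

fwd_flops = 2 * m * k * n                  # C = A @ B
bwd_flops = 2 * m * n * k + 2 * k * m * n  # dA = dC @ B^T, dB = A^T @ dC
assert bwd_flops == 2 * fwd_flops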
@@ -3,7 +3,7 @@ from typing import List, Tuple

import torch

from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem

from ..registry import meta_register
@@ -1,22 +1,14 @@
from typing import Callable, Dict, List, Tuple, Union
from typing import List, Tuple

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    MemoryCost,
    OperationData,
    OperationDataType,
    ShardingStrategy,
    StrategiesVector,
    TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem

from ..registry import meta_register

__all__ = ['batchnormnd_meta_info', 'layernorm_meta_info']
__all__ = ["batchnormnd_meta_info", "layernorm_meta_info"]


@meta_register.register(torch.nn.BatchNorm1d)

@@ -65,7 +57,15 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
    # saved inv std and some other args indicating the status of the module
    # the bwd outputs are input grad, weight grad and bias grad
    bwd_in_args = [
        output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch
        output_tensor,
        output_tensor,
        weight_tensor,
        mean_tensor,
        var_tensor,
        mean_tensor,
        var_tensor,
        1e-5,
        num_batch,
    ]
    bwd_out_args = [input_tensor, weight_tensor, bias_tensor]

@@ -77,29 +77,34 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
    # calculate memory cost
    # the fwd activation cost is output plus saved mean and saved inv std
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
        [input_tensor, output_tensor, mean_tensor, var_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                 temp=0,
                                 buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
    fwd_memory_cost = MemoryCost(
        activation=compute_size_in_bytes([input_tensor, output_tensor, mean_tensor, var_tensor]),
        parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
        temp=0,
        buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
    )

    # the bwd memory cost is quite tricky here, BatchNorm will remove saved mean
    # and saved inv std during backward phase
    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                 temp=compute_size_in_bytes([mean_tensor, var_tensor]),
                                 buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
    bwd_memory_cost = MemoryCost(
        activation=compute_size_in_bytes([input_tensor]),
        parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
        temp=compute_size_in_bytes([mean_tensor, var_tensor]),
        buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
    )

    # total cost is the sum of forward and backward cost
    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
    total_cost = MemoryCost(
        activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
        parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
    )

    memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

    # store fwd_in, fwd_buffer, fwd_out
    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
    fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')]
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
    fwd_in = [torch.zeros_like(input_tensor, device="meta")]
    fwd_buffer = [torch.zeros_like(mean_tensor, device="meta"), torch.zeros_like(var_tensor, device="meta")]
    fwd_out = [torch.zeros_like(output_tensor, device="meta")]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out

@@ -116,8 +121,8 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
    output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
    weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
    bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
    running_mean = torch.rand(input_tensor.shape[0], 1, device='meta')
    running_var = torch.rand(input_tensor.shape[0], 1, device='meta')
    running_mean = torch.rand(input_tensor.shape[0], 1, device="meta")
    running_var = torch.rand(input_tensor.shape[0], 1, device="meta")

    # construct args
    fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor]

@@ -132,27 +137,32 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem

    # memory cost
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
        [input_tensor, output_tensor, weight_tensor, bias_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                 temp=0,
                                 buffer=compute_size_in_bytes([running_mean, running_var]))
    fwd_memory_cost = MemoryCost(
        activation=compute_size_in_bytes([input_tensor, output_tensor, weight_tensor, bias_tensor]),
        parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
        temp=0,
        buffer=compute_size_in_bytes([running_mean, running_var]),
    )

    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                 temp=compute_size_in_bytes([running_mean, running_var]),
                                 buffer=compute_size_in_bytes([running_mean, running_var]))
    bwd_memory_cost = MemoryCost(
        activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
        parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
        temp=compute_size_in_bytes([running_mean, running_var]),
        buffer=compute_size_in_bytes([running_mean, running_var]),
    )

    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
                            temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
                            buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
    total_cost = MemoryCost(
        activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
        parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
        temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
        buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
    )

    memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

    # store fwd_in, fwd_buffer, fwd_out
    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
    fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')]
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
    fwd_in = [torch.zeros_like(input_tensor, device="meta")]
    fwd_buffer = [torch.zeros_like(running_mean, device="meta"), torch.zeros_like(running_var, device="meta")]
    fwd_out = [torch.zeros_like(output_tensor, device="meta")]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
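Note: the saved mean and saved inv std are the interesting part of this accounting: they are counted as a forward buffer and then charged again as backward temp, because BatchNorm frees them during the backward phase. A small sketch of those two entries, with a stand-in byte counter (shapes are illustrative):

import torch

def bytes_of(tensors) -> int:
    # simplified stand-in for compute_size_in_bytes
    return sum(t.numel() * t.element_size() for t in tensors)

num_features = 64
mean_tensor = torch.zeros(num_features, device="meta")
var_tensor = torch.zeros(num_features, device="meta")

fwd_buffer = bytes_of([mean_tensor, var_tensor])  # kept alive after forward
bwd_temp = bytes_of([mean_tensor, var_tensor])    # freed during backward
print(fwd_buffer, bwd_temp)  # 512 512 in fp32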
@@ -63,7 +63,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
    # store fwd_in, fwd_buffer, fwd_out
    fwd_in = []
    fwd_buffer = []
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
    fwd_out = [torch.zeros_like(output_tensor, device="meta")]

    return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out

@@ -117,8 +117,10 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
    fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix]))

    # temp memory for backward is the index matrix to be discarded
    bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
                              temp=compute_size_in_bytes(index_matrix))
    bwd_mem_cost = MemoryCost(
        activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
        temp=compute_size_in_bytes(index_matrix),
    )

    # total cost
    total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)

@@ -126,8 +128,8 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
    mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

    # store fwd_in, fwd_buffer, fwd_out
    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
    fwd_buffer = [torch.zeros_like(index_matrix, device='meta')]
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
    fwd_in = [torch.zeros_like(input_tensor, device="meta")]
    fwd_buffer = [torch.zeros_like(index_matrix, device="meta")]
    fwd_out = [torch.zeros_like(output_tensor, device="meta")]

    return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
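Note: max pooling saves an index matrix (the argmax positions) for backward, so it appears in the forward activation and again as backward temp once the gradient has been routed and the indices can be dropped. A rough size sketch, assuming the indices are stored as int64 next to an fp32 output (shapes are illustrative):

import torch

output = torch.zeros(8, 16, 14, 14, device="meta")  # pooled fp32 output
index_matrix = torch.zeros(8, 16, 14, 14, dtype=torch.int64, device="meta")

out_bytes = output.numel() * output.element_size()              # 4 bytes/elem
idx_bytes = index_matrix.numel() * index_matrix.element_size()  # 8 bytes/elem
print(out_bytes, idx_bytes)  # 100352 200704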
@@ -2,7 +2,6 @@ from typing import Callable, List, Tuple

import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem

@@ -37,15 +36,19 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0)

    bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
                              parameter=0,
                              temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
                              buffer=0)
    bwd_mem_cost = MemoryCost(
        activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
        parameter=0,
        temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
        buffer=0,
    )

    total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
                                parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
                                temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
                                buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
    total_mem_cost = MemoryCost(
        activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
        parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
        temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
        buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
    )

    memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

@@ -66,14 +69,24 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f

# register torch.Tensor related metainfo
# (0, 0)
meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze,
                        torch.arange])(tensor_related_metainfo(0, 0))
meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, torch.arange])(
    tensor_related_metainfo(0, 0)
)

# (1, 0)
meta_register.register([
    torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute,
    torch.Tensor.split, torch.split, torch.Tensor.view
])(tensor_related_metainfo(1, 0))
meta_register.register(
    [
        torch.Tensor.flatten,
        torch.flatten,
        torch.Tensor.transpose,
        torch.transpose,
        torch.Tensor.permute,
        torch.permute,
        torch.Tensor.split,
        torch.split,
        torch.Tensor.view,
    ]
)(tensor_related_metainfo(1, 0))

# (1, 1)
meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1))
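Note: `tensor_related_metainfo` is a factory: calling it with the two scaling factors returns the actual meta function, which is then registered for a whole list of ops in one decorator call. A minimal sketch of the same pattern with a hypothetical dict-based registry (names and return values here are illustrative, not the project's real ones):

def tensor_related_metainfo(out_factor: float = 1, tmp_factor: float = 0):
    # factory: bake the scaling factors into the returned meta function
    def meta_func(*args, **kwargs):
        return {"bwd_out_scale": out_factor, "bwd_tmp_scale": tmp_factor}
    return meta_func

registry = {}

def register(sources):
    # decorator that accepts a single op or a list, like meta_register.register
    def wrapper(func):
        for s in (sources if isinstance(sources, (list, tuple)) else [sources]):
            registry[s] = func
        return func
    return wrapper

register(["flatten", "transpose"])(tensor_related_metainfo(1, 0))
print(registry["flatten"]())  # {'bwd_out_scale': 1, 'bwd_tmp_scale': 0}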
@@ -4,7 +4,7 @@ import torch

from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem

from ..registry import meta_register

@@ -39,16 +39,21 @@ def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li
    # gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase
    # NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward
    fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor]))
    bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
                              parameter=0,
                              temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) -
                              activation_size([x_tensor, y_tensor]),
                              buffer=0)
    bwd_mem_cost = MemoryCost(
        activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
        parameter=0,
        temp=activation_size([output_tensor]) * 3
        + activation_size([condition_tensor])
        - activation_size([x_tensor, y_tensor]),
        buffer=0,
    )

    total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
                                parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
                                temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
                                buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
    total_mem_cost = MemoryCost(
        activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
        parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
        temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
        buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
    )

    memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
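Note: the backward temp for `torch.where` is written as 3 * |output| + |condition| - |x, y|. If condition, x, y and the output all occupy the same number of bytes B, the entries reduce to simple multiples of B, which makes the formula easier to sanity-check (B below is an arbitrary illustrative size):

B = 4 * 1024 * 1024  # assume every tensor occupies B bytes

fwd_activation = 4 * B        # condition + x + y + output
bwd_activation = 2 * B - B    # grads for x and y, minus the condition tensor
bwd_temp = 3 * B + B - 2 * B  # 3*|out| + |cond| - |x, y| = 2B
print(fwd_activation, bwd_activation, bwd_temp)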
@@ -1,14 +1,12 @@
__all__ = ['Registry']
__all__ = ["Registry"]


class Registry:

    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, source):

        def wrapper(func):
            if isinstance(source, (list, tuple)):
                # support register a list of items for this func

@@ -21,7 +19,7 @@ class Registry:
        return wrapper

    def get(self, source):
        assert source in self.store, f'{source} not found in the {self.name} registry'
        assert source in self.store, f"{source} not found in the {self.name} registry"
        target = self.store[source]
        return target

@@ -29,4 +27,4 @@ class Registry:
        return source in self.store


meta_register = Registry('meta')
meta_register = Registry("meta")
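Note: this registry is the hub the rest of the diff revolves around: a meta function is attached to one target (a module class or a function) or to a list of targets via the `register` decorator, and looked up later with `has`/`get`. A usage sketch (`relu_meta_info` is a placeholder here, not the project's actual ReLU handler):

import torch

@meta_register.register(torch.nn.ReLU)
def relu_meta_info(*args, **kwargs):
    ...

# a list of targets can share one meta function
meta_register.register([torch.flatten, torch.Tensor.flatten])(relu_meta_info)

if meta_register.has(torch.nn.ReLU):
    meta_func = meta_register.get(torch.nn.ReLU)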
@@ -2,20 +2,13 @@ from typing import Callable, List

import torch

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    MemoryCost,
    OperationData,
    OperationDataType,
    ShardingStrategy,
    StrategiesVector,
    TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem
from colossalai.tensor.sharding_spec import ShardingSpec

from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register

__all__ = ['ShardMetaInfo']
__all__ = ["ShardMetaInfo"]


class ShardMetaInfo:

@@ -76,10 +69,12 @@ class ShardMetaInfo:
        """

        if isinstance(sharding_spec, ShardingSpec):
            op_data = OperationData(name=operation_data.name,
                                    data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
                                    type=operation_data.type,
                                    logical_shape=operation_data.logical_shape)
            op_data = OperationData(
                name=operation_data.name,
                data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
                type=operation_data.type,
                logical_shape=operation_data.logical_shape,
            )
        elif isinstance(sharding_spec, (list, tuple)):
            data = operation_data.data
            assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."

@@ -97,8 +92,9 @@ class ShardMetaInfo:
        """
        Compute meta info based on sharding strategy and the given target function.
        """
        assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
            f"Meta info for {self._target} is not registered."
        assert meta_register.has(self._target.__class__) or meta_register.has(
            self._target
        ), f"Meta info for {self._target} is not registered."
        if meta_register.has(self._target.__class__):
            # module
            meta_func = meta_register.get(self._target.__class__)

@@ -117,11 +113,11 @@ class ShardMetaInfo:

        # construct kwargs
        if self.target in INPLACE_MODULE:
            kwargs = {'inplace': self.target.inplace}
            kwargs = {"inplace": self.target.inplace}
        elif self.target in INPLACE_OPS:
            kwargs = {'inplace': True}
            kwargs = {"inplace": True}
        else:
            kwargs = {'inplace': False}
            kwargs = {"inplace": False}

        # compute metainfo with meta_func
        self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
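Note: the kwargs built here are what feed the `is_inplace` flag in the elementwise handler at the top of this diff: targets listed in INPLACE_MODULE forward their own `inplace` attribute, targets in INPLACE_OPS are always inplace, and everything else is not. A condensed sketch of the dispatch as a standalone helper (the function name is hypothetical; it simply mirrors the branch structure above):

def build_inplace_kwargs(target, inplace_modules, inplace_ops) -> dict:
    # mirrors the three branches of ShardMetaInfo's kwargs construction
    if target in inplace_modules:
        return {"inplace": target.inplace}
    if target in inplace_ops:
        return {"inplace": True}
    return {"inplace": False}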