[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
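For context on the last two bullets: "ignore cuda for clang-format" typically means excluding CUDA sources from the clang-format hook so that only C/C++ files are formatted. A minimal sketch of what such a .pre-commit-config.yaml entry can look like — the hook repos, revisions, and exclude pattern here are illustrative assumptions, not the commit's actual configuration:

repos:
  - repo: https://github.com/psf/black
    rev: 23.9.1
    hooks:
      - id: black
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6
    hooks:
      - id: clang-format
        exclude: \.(cu|cuh)$   # "ignore cuda for clang-format"

Re-running the updated hooks over the whole tree ("run pre-commit" above) is done with `pre-commit run --all-files`, which is what produces the sweeping reformatting diffs below: single quotes become double quotes and long call sites are re-wrapped in black's trailing-comma style.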

View File

@@ -3,8 +3,6 @@ import operator
import torch
import torch.nn as nn
from ..tensor_shard.constants import *
# list of inplace modules
INPLACE_MODULE = [nn.ReLU]

View File

@@ -25,28 +25,32 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0
def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
input_tensor = next(
filter(
lambda x:
(x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim',
args)).data
lambda x: (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM)
and x.name != "softmax_dim",
args,
)
).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
is_inplace = 1 if kwargs.get('inplace', False) else 0
is_inplace = 1 if kwargs.get("inplace", False) else 0
flop_counter = elementwise_flop_counter(1, 0)
# calculate compute cost
fwd_compute_cost = flop_counter([input_tensor], [output_tensor])
bwd_compute_cost = flop_counter([output_tensor], [input_tensor])
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
# calculate memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
# NOTE: if inplace is True, we will not create a new tensor in forward
fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace),
parameter=0,
temp=0,
buffer=activation_size(input_tensor) * buffer_mem_scale)
fwd_memory_cost = MemoryCost(
activation=activation_size(input_tensor) * (2 - is_inplace),
parameter=0,
temp=0,
buffer=activation_size(input_tensor) * buffer_mem_scale,
)
# temp_mem_scale is for situations like softmax backward
# the buffer will be removed during backward phase
@@ -54,20 +58,23 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0
activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale,
parameter=0,
temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale,
buffer=0)
buffer=0,
)
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(output_tensor, device="meta")]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
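The (2 - is_inplace) factor above encodes the two NOTE comments: an out-of-place elementwise op keeps both the input and the newly created output alive in forward, while an inplace op keeps only one buffer. A quick illustrative check with an assumed activation size:

input_bytes = 4 * 1024 * 1024  # assumed 4 MiB activation, purely illustrative

for inplace in (False, True):
    is_inplace = 1 if inplace else 0
    # out-of-place: input + new output alive (2x); inplace: one buffer (1x)
    fwd_activation = input_bytes * (2 - is_inplace)
    print(f"inplace={inplace}: {fwd_activation / 2**20:.0f} MiB")  # 8 MiB, then 4 MiB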

View File

@@ -6,10 +6,10 @@ from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
from ..constants import BCAST_FUNC_OP
from ..registry import meta_register
__all__ = ['binary_elementwise_meta_info']
__all__ = ["binary_elementwise_meta_info"]
@meta_register.register(BCAST_FUNC_OP)
@@ -61,6 +61,6 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
fwd_out = [torch.zeros_like(output_op_data.data, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out

View File

@@ -1,22 +1,14 @@
from typing import Callable, Dict, List, Tuple, Union
from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
__all__ = ['convnd_meta_info']
__all__ = ["convnd_meta_info"]
@meta_register.register(torch.nn.Conv1d)
@@ -103,35 +95,47 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \
flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
bwd_compute_cost = (
flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor))
if has_bias
else flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# TODO: use profiler to check conv temp memory
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0,
)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0,
)
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out

View File

@@ -24,8 +24,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor])
bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor],
[weight_tensor])
bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default](
[output_tensor, weight_tensor], [weight_tensor]
)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
@@ -34,10 +35,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# NOTE: during the backward phase of torch.nn.Embedding, a temporary allocation appears once the input is
# large enough, for reasons we have not identified yet; since it is significantly smaller than the gradient
# memory, we currently assume there is no temp memory
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=0,
temp=0,
buffer=0)
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]), parameter=0, temp=0, buffer=0
)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0)
total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)

View File

@@ -1,23 +1,15 @@
from functools import reduce
from typing import Callable, Dict, List, Tuple, Union
from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
__all__ = ['linear_meta_info', 'matmul_meta_info']
__all__ = ["linear_meta_info", "matmul_meta_info"]
@meta_register.register(torch.nn.functional.linear)
@@ -100,32 +92,43 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
[bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \
flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
[bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
)
bwd_compute_cost = (
flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,))
+ flop_mapping[torch.ops.aten.mm.default](
[torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
)
+ flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
# calculate memory cost
# NOTE: Linear has no buffer or temp memory in the forward and backward phases
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=0)
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=0,
)
# the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=0,
)
# total cost is the sum of the forward and backward costs
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
@@ -136,39 +139,49 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))
[input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
)
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[output_tensor, weight_tensor], (input_tensor,)
) + flop_mapping[torch.ops.aten.mm.default](
[torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
# calculate memory cost
# NOTE: Linear has no buffer or temp memory in the forward and backward phases
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0,
)
# the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0,
)
# total cost is the sum of the forward and backward costs
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
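As a back-of-envelope check on the cost structure above (one addmm in forward; two mm plus one sum in backward, covering the input grad, weight grad, and bias reduction), here is the arithmetic with assumed sizes and an assumed 2*M*K*N flops-per-matmul convention — the library's own flop counter may use a different convention, such as counting multiply-accumulates:

B, n_in, n_out = 32, 128, 64    # assumed batch size and feature dims
fwd = 2 * B * n_in * n_out      # y = x @ W^T (+ bias)
bwd = (2 * B * n_out * n_in     # dx = dy @ W
       + 2 * n_out * B * n_in   # dW = dy^T @ x
       + B * n_out)             # db = dy.sum(dim=0)
print(fwd, bwd)                 # 524288 1050624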
@@ -222,15 +235,16 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# batched gemv case 1: batched matrix-vector multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors
)
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0].reshape(-1), input_tensors[1]],
output_tensors) + \
flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
output_tensors)
[output_tensors[0].reshape(-1), input_tensors[1]], output_tensors
) + flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
output_tensors,
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
@@ -239,86 +253,104 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# gemv case 2: vector-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0], input_tensors[0]], output_tensors
) + flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]),
buffer=0)
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensors),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]),
buffer=0,
)
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
# batched gemv case 2: vector-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
[output_tensors[0].reshape(-1)])
[output_tensors[0].reshape(-1)],
)
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0].reshape(-1), input_tensors[0]],
output_tensors
) + \
flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
output_tensors
)
[output_tensors[0].reshape(-1), input_tensors[0]], output_tensors
) + flop_mapping[torch.ops.aten.matmul.default](
[
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1),
output_tensors[0].reshape(-1),
],
output_tensors,
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]),
buffer=0)
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]),
buffer=0,
)
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
# gemm & batched gemm case 1: batched matrix-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
)
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
[input_tensors[1]]
) + \
flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
)
[
input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1),
output_tensors[0].reshape(-1, output_tensors[0].shape[-1]),
],
[input_tensors[1]],
) + flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1])],
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
# batched gemm case 2: matrix-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose(
0, 1)
], [output_tensors[0].transpose(-2, -1)])
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
input_tensors[0].transpose(0, 1),
],
[output_tensors[0].transpose(-2, -1)],
)
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
[input_tensors[0]]
) + \
flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
)
[
output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1),
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
],
[input_tensors[0]],
) + flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) +
compute_size_in_bytes(input_tensors[1]),
temp=compute_size_in_bytes(output_tensors))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors))
fwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(output_tensors) + compute_size_in_bytes(input_tensors[1]),
temp=compute_size_in_bytes(output_tensors),
)
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors),
)
elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
# Batched matrix-batched matrix multiplication
# Fetch shape of the two inputs and see if the batch dimensions are the same
_is_batch_dims_same = True
if len(input_tensors[0].shape) == len(input_tensors[1].shape):
for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
for shape_0, shape_1 in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
if shape_0 != shape_1:
_is_batch_dims_same = False
break
@@ -337,20 +369,28 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# Case 1: batch dimensions are the same
# Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([
input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape(
-1, input_dim_10, input_dim_11)
], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[
input_tensors[0].reshape(-1, input_dim_00, input_dim_01),
input_tensors[1].reshape(-1, input_dim_10, input_dim_11),
],
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
)
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
[input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]
) + \
flop_mapping[torch.ops.aten.bmm.default](
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
[input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
)
[
input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00),
output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
],
[input_tensors[1].reshape(-1, input_dim_11, input_dim_10)],
) + flop_mapping[torch.ops.aten.bmm.default](
[
output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10),
],
[input_tensors[0].reshape(-1, input_dim_00, input_dim_01)],
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors))
@@ -358,43 +398,46 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
else:
# Case 2: batch dimensions are different
batch_dims = output_tensors[0].shape[:-2]
extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
input_dim_00,
input_dim_01,
device="meta")
extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
input_dim_10,
input_dim_11,
device="meta")
extended_input_0 = torch.rand(
reduce(lambda x, y: x * y, batch_dims), input_dim_00, input_dim_01, device="meta"
)
extended_input_1 = torch.rand(
reduce(lambda x, y: x * y, batch_dims), input_dim_10, input_dim_11, device="meta"
)
# Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
[extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]
)
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
[extended_input_1]
) + \
flop_mapping[torch.ops.aten.bmm.default](
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
[extended_input_0]
)
[extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
[extended_input_1],
) + flop_mapping[torch.ops.aten.bmm.default](
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
[extended_input_0],
)
fwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) -
compute_size_in_bytes([extended_input_0, extended_input_1]),
temp=compute_size_in_bytes([extended_input_0, extended_input_1]))
activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])
)
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensors)
- compute_size_in_bytes([extended_input_0, extended_input_1]),
temp=compute_size_in_bytes([extended_input_0, extended_input_1]),
)
# compute cost
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# memory cost
total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
total_cost = MemoryCost(
activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
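The long chain of elif branches above keys on the ranks of the two matmul operands. A small meta-device sketch — shapes are assumed, chosen only to hit each branch — showing which case each pair falls into and the resulting output shape:

import torch

pairs = [
    (torch.rand(8, 4, 3, device="meta"), torch.rand(3, device="meta")),        # batched gemv case 1
    (torch.rand(3, device="meta"), torch.rand(3, 5, device="meta")),           # gemv case 2
    (torch.rand(3, device="meta"), torch.rand(8, 3, 5, device="meta")),        # batched gemv case 2
    (torch.rand(8, 4, 3, device="meta"), torch.rand(3, 5, device="meta")),     # gemm / batched gemm case 1
    (torch.rand(4, 3, device="meta"), torch.rand(8, 3, 5, device="meta")),     # batched gemm case 2
    (torch.rand(8, 4, 3, device="meta"), torch.rand(8, 3, 5, device="meta")),  # bmm with equal batch dims
]
for a, b in pairs:
    print(tuple(a.shape), "@", tuple(b.shape), "->", tuple(torch.matmul(a, b).shape))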

View File

@@ -3,7 +3,7 @@ from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register

View File

@@ -1,22 +1,14 @@
from typing import Callable, Dict, List, Tuple, Union
from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
__all__ = ['batchnormnd_meta_info', 'layernorm_meta_info']
__all__ = ["batchnormnd_meta_info", "layernorm_meta_info"]
@meta_register.register(torch.nn.BatchNorm1d)
@@ -65,7 +57,15 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# saved inv std and some other args indicating the status of the module
# the bwd outputs are input grad, weight grad and bias grad
bwd_in_args = [
output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch
output_tensor,
output_tensor,
weight_tensor,
mean_tensor,
var_tensor,
mean_tensor,
var_tensor,
1e-5,
num_batch,
]
bwd_out_args = [input_tensor, weight_tensor, bias_tensor]
@@ -77,29 +77,34 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# calculate memory cost
# the fwd activation cost is output plus saved mean and saved inv std
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
[input_tensor, output_tensor, mean_tensor, var_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor, mean_tensor, var_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
)
# the bwd memory cost is tricky here: BatchNorm removes the saved mean
# and saved inv std during the backward phase
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=compute_size_in_bytes([mean_tensor, var_tensor]),
buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=compute_size_in_bytes([mean_tensor, var_tensor]),
buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
)
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = [torch.zeros_like(mean_tensor, device="meta"), torch.zeros_like(var_tensor, device="meta")]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@@ -116,8 +121,8 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
running_mean = torch.rand(input_tensor.shape[0], 1, device='meta')
running_var = torch.rand(input_tensor.shape[0], 1, device='meta')
running_mean = torch.rand(input_tensor.shape[0], 1, device="meta")
running_var = torch.rand(input_tensor.shape[0], 1, device="meta")
# construct args
fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor]
@@ -132,27 +137,32 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
[input_tensor, output_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=compute_size_in_bytes([running_mean, running_var]))
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=compute_size_in_bytes([running_mean, running_var]),
)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=compute_size_in_bytes([running_mean, running_var]),
buffer=compute_size_in_bytes([running_mean, running_var]))
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=compute_size_in_bytes([running_mean, running_var]),
buffer=compute_size_in_bytes([running_mean, running_var]),
)
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = [torch.zeros_like(running_mean, device="meta"), torch.zeros_like(running_var, device="meta")]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out

View File

@@ -63,7 +63,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
@@ -117,8 +117,10 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix]))
# temp memory for backward is the index matrix to be discarded
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
temp=compute_size_in_bytes(index_matrix))
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
temp=compute_size_in_bytes(index_matrix),
)
# total cost
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)
@@ -126,8 +128,8 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(index_matrix, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = [torch.zeros_like(index_matrix, device="meta")]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out

View File

@@ -2,7 +2,6 @@ from typing import Callable, List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
@@ -37,15 +36,19 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
parameter=0,
temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
buffer=0)
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
parameter=0,
temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
buffer=0,
)
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
total_mem_cost = MemoryCost(
activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
@@ -66,14 +69,24 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
# register torch.Tensor related metainfo
# (0, 0)
meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze,
torch.arange])(tensor_related_metainfo(0, 0))
meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, torch.arange])(
tensor_related_metainfo(0, 0)
)
# (1, 0)
meta_register.register([
torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute,
torch.Tensor.split, torch.split, torch.Tensor.view
])(tensor_related_metainfo(1, 0))
meta_register.register(
[
torch.Tensor.flatten,
torch.flatten,
torch.Tensor.transpose,
torch.transpose,
torch.Tensor.permute,
torch.permute,
torch.Tensor.split,
torch.split,
torch.Tensor.view,
]
)(tensor_related_metainfo(1, 0))
# (1, 1)
meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1))
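The (0, 0) / (1, 0) / (1, 1) comments above are the (bwd_mem_out_factor, bwd_mem_tmp_factor) arguments of tensor_related_metainfo, which scale the backward activation and temp costs computed earlier in this file. A toy sketch of that scaling with an assumed output size:

out_bytes = 1024  # assumed total output size in bytes

def bwd_mem(bwd_mem_out_factor, bwd_mem_tmp_factor):
    # mirrors the bwd_mem_cost construction shown above
    return out_bytes * bwd_mem_out_factor, out_bytes * bwd_mem_tmp_factor

print(bwd_mem(0, 0))  # torch.arange etc.: no backward activation or temp
print(bwd_mem(1, 0))  # torch.flatten etc.: keeps backward activation, no temp
print(bwd_mem(1, 1))  # Tensor.contiguous etc.: backward activation plus temp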

View File

@@ -4,7 +4,7 @@ import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
@@ -39,16 +39,21 @@ def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li
# gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase
# NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward
fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor]))
bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
parameter=0,
temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) -
activation_size([x_tensor, y_tensor]),
buffer=0)
bwd_mem_cost = MemoryCost(
activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
parameter=0,
temp=activation_size([output_tensor]) * 3
+ activation_size([condition_tensor])
- activation_size([x_tensor, y_tensor]),
buffer=0,
)
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
total_mem_cost = MemoryCost(
activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

View File

@@ -1,14 +1,12 @@
__all__ = ['Registry']
__all__ = ["Registry"]
class Registry:
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
def wrapper(func):
if isinstance(source, (list, tuple)):
# support register a list of items for this func
@@ -21,7 +19,7 @@ class Registry:
return wrapper
def get(self, source):
assert source in self.store, f'{source} not found in the {self.name} registry'
assert source in self.store, f"{source} not found in the {self.name} registry"
target = self.store[source]
return target
@@ -29,4 +27,4 @@ class Registry:
return source in self.store
meta_register = Registry('meta')
meta_register = Registry("meta")
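For illustration, a minimal usage sketch of the Registry above, mirroring the meta_register.register(...)(func) call pattern used throughout this commit; dummy_meta_info is a hypothetical stand-in for a real meta function:

import torch

example_register = Registry("example")

def dummy_meta_info(*args, **kwargs):
    return None

example_register.register(torch.nn.ReLU)(dummy_meta_info)                   # single target
example_register.register([torch.flatten, torch.permute])(dummy_meta_info)  # list of targets
assert example_register.has(torch.nn.ReLU)
assert example_register.get(torch.nn.ReLU) is dummy_meta_info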

View File

@@ -2,20 +2,13 @@ from typing import Callable, List
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register
__all__ = ['ShardMetaInfo']
__all__ = ["ShardMetaInfo"]
class ShardMetaInfo:
@@ -76,10 +69,12 @@ class ShardMetaInfo:
"""
if isinstance(sharding_spec, ShardingSpec):
op_data = OperationData(name=operation_data.name,
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
type=operation_data.type,
logical_shape=operation_data.logical_shape)
op_data = OperationData(
name=operation_data.name,
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
type=operation_data.type,
logical_shape=operation_data.logical_shape,
)
elif isinstance(sharding_spec, (list, tuple)):
data = operation_data.data
assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."
@@ -97,8 +92,9 @@ class ShardMetaInfo:
"""
Compute meta info based on sharding strategy and the given target function.
"""
assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
f"Meta info for {self._target} is not registered."
assert meta_register.has(self._target.__class__) or meta_register.has(
self._target
), f"Meta info for {self._target} is not registered."
if meta_register.has(self._target.__class__):
# module
meta_func = meta_register.get(self._target.__class__)
@@ -117,11 +113,11 @@ class ShardMetaInfo:
# construct kwargs
if self.target in INPLACE_MODULE:
kwargs = {'inplace': self.target.inplace}
kwargs = {"inplace": self.target.inplace}
elif self.target in INPLACE_OPS:
kwargs = {'inplace': True}
kwargs = {"inplace": True}
else:
kwargs = {'inplace': False}
kwargs = {"inplace": False}
# compute metainfo with meta_func
self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
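Putting the pieces together, the dispatch above reduces to roughly the following sketch — it assumes target and the sharded args have been prepared as in ShardMetaInfo, and collapses the three inplace branches into one getattr for brevity; it is not the verbatim implementation:

def query_meta_info(target, *args):
    # module targets are looked up by class, function targets by the callable itself
    if meta_register.has(target.__class__):
        meta_func = meta_register.get(target.__class__)  # e.g. a torch.nn.ReLU() module
    else:
        meta_func = meta_register.get(target)            # e.g. torch.nn.functional.linear
    # stands in for the INPLACE_MODULE / INPLACE_OPS / default branches above
    kwargs = {"inplace": getattr(target, "inplace", False)}
    return meta_func(*args, **kwargs)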