mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,23 +1,15 @@
 from functools import reduce
-from typing import Callable, Dict, List, Tuple, Union
+from typing import List, Tuple

 import torch

 from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
 from colossalai._analyzer.fx.node_util import compute_size_in_bytes
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    MemoryCost,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-    StrategiesVector,
-    TrainCycleItem,
-)
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem

 from ..registry import meta_register

-__all__ = ['linear_meta_info', 'matmul_meta_info']
+__all__ = ["linear_meta_info", "matmul_meta_info"]


 @meta_register.register(torch.nn.functional.linear)
@@ -100,32 +92,43 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L

     # calculate compute cost
     fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
-        [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
-    bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
-        flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \
-        flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
-    compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
-                                  bwd=bwd_compute_cost,
-                                  total=fwd_compute_cost + bwd_compute_cost)
+        [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
+    )
+    bwd_compute_cost = (
+        flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,))
+        + flop_mapping[torch.ops.aten.mm.default](
+            [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
+        )
+        + flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
+    )
+    compute_cost = TrainCycleItem(
+        fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+    )

     # calculate memory cost
     # NOTE: Linear don't have buffer and temp in forward and backward phase
     # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
-                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
-                                 temp=0,
-                                 buffer=0)
+    fwd_memory_cost = MemoryCost(
+        activation=compute_size_in_bytes([input_tensor, output_tensor]),
+        parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+        temp=0,
+        buffer=0,
+    )

     # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
-    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
-                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
-                                 temp=0,
-                                 buffer=0)
+    bwd_memory_cost = MemoryCost(
+        activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
+        parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+        temp=0,
+        buffer=0,
+    )

     # total cost is to sum the forward and backward cost
-    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
-                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+    total_cost = MemoryCost(
+        activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+        parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+    )

     memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
@@ -136,39 +139,49 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L

     # calculate compute cost
     fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
-        [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
-    bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
-        flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))
+        [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
+    )
+    bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
+        [output_tensor, weight_tensor], (input_tensor,)
+    ) + flop_mapping[torch.ops.aten.mm.default](
+        [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
+    )

-    compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
-                                  bwd=bwd_compute_cost,
-                                  total=fwd_compute_cost + bwd_compute_cost)
+    compute_cost = TrainCycleItem(
+        fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+    )

     # calculate memory cost
     # NOTE: Linear don't have buffer and temp in forward and backward phase
     # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
-                                 parameter=compute_size_in_bytes(weight_tensor),
-                                 temp=0,
-                                 buffer=0)
+    fwd_memory_cost = MemoryCost(
+        activation=compute_size_in_bytes([input_tensor, output_tensor]),
+        parameter=compute_size_in_bytes(weight_tensor),
+        temp=0,
+        buffer=0,
+    )

     # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
-    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]),
-                                 parameter=compute_size_in_bytes(weight_tensor),
-                                 temp=0,
-                                 buffer=0)
+    bwd_memory_cost = MemoryCost(
+        activation=compute_size_in_bytes([input_tensor, weight_tensor]),
+        parameter=compute_size_in_bytes(weight_tensor),
+        temp=0,
+        buffer=0,
+    )

     # total cost is to sum the forward and backward cost
-    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
-                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+    total_cost = MemoryCost(
+        activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+        parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+    )

     memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

     # store fwd_in, fwd_buffer, fwd_out
-    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
+    fwd_in = [torch.zeros_like(input_tensor, device="meta")]
     fwd_buffer = []
-    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+    fwd_out = [torch.zeros_like(output_tensor, device="meta")]

     return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@@ -222,15 +235,16 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
         # batched gemv case 1: batched matrix-vector multiplication

         fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
-            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
+            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors
+        )

         # combine the dimensions of output
         bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
-            [output_tensors[0].reshape(-1), input_tensors[1]],
-            output_tensors) + \
-            flop_mapping[torch.ops.aten.matmul.default](
-            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
-            output_tensors)
+            [output_tensors[0].reshape(-1), input_tensors[1]], output_tensors
+        ) + flop_mapping[torch.ops.aten.matmul.default](
+            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
+            output_tensors,
+        )

         fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
         bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
@@ -239,86 +253,104 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
         # gemv case 2: vector-matrix multiplication
         fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)

-        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
-            flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
+        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
+            [output_tensors[0], input_tensors[0]], output_tensors
+        ) + flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)

         fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
-        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors),
-                                  parameter=0,
-                                  temp=compute_size_in_bytes(input_tensors[1]),
-                                  buffer=0)
+        bwd_mem_cost = MemoryCost(
+            activation=compute_size_in_bytes(input_tensors),
+            parameter=0,
+            temp=compute_size_in_bytes(input_tensors[1]),
+            buffer=0,
+        )

     elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
         # batched gemv case 2: vector-batched matrix multiplication

         fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
             [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
-            [output_tensors[0].reshape(-1)])
+            [output_tensors[0].reshape(-1)],
+        )

         # combine the dimensions of output
         bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
-            [output_tensors[0].reshape(-1), input_tensors[0]],
-            output_tensors
-        ) + \
-            flop_mapping[torch.ops.aten.matmul.default](
-            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
-            output_tensors
-        )
+            [output_tensors[0].reshape(-1), input_tensors[0]], output_tensors
+        ) + flop_mapping[torch.ops.aten.matmul.default](
+            [
+                input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1),
+                output_tensors[0].reshape(-1),
+            ],
+            output_tensors,
+        )

         fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
-        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
-                                  parameter=0,
-                                  temp=compute_size_in_bytes(input_tensors[1]),
-                                  buffer=0)
+        bwd_mem_cost = MemoryCost(
+            activation=compute_size_in_bytes(input_tensors[0]),
+            parameter=0,
+            temp=compute_size_in_bytes(input_tensors[1]),
+            buffer=0,
+        )

     elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
         # gemm & batched gemm case 1: batched matrix-matrix multiplication

         fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
             [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
-            [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])
+            [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
+        )

         bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
-            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
-            [input_tensors[1]]
-        ) + \
-            flop_mapping[torch.ops.aten.mm.default](
-            [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
-            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
-        )
+            [
+                input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1),
+                output_tensors[0].reshape(-1, output_tensors[0].shape[-1]),
+            ],
+            [input_tensors[1]],
+        ) + flop_mapping[torch.ops.aten.mm.default](
+            [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
+            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])],
+        )

         fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
         bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)

     elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
         # batched gemm case 2: matrix-batched matrix multiplication
-        fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([
-            input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose(
-                0, 1)
-        ], [output_tensors[0].transpose(-2, -1)])
+        fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
+            [
+                input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
+                input_tensors[0].transpose(0, 1),
+            ],
+            [output_tensors[0].transpose(-2, -1)],
+        )

         bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
-            [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
-            [input_tensors[0]]
-        ) + \
-            flop_mapping[torch.ops.aten.mm.default](
-            [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
-            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
-        )
+            [
+                output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1),
+                input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
+            ],
+            [input_tensors[0]],
+        ) + flop_mapping[torch.ops.aten.mm.default](
+            [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
+            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
+        )

-        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) +
-                                  compute_size_in_bytes(input_tensors[1]),
-                                  temp=compute_size_in_bytes(output_tensors))
-        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
-                                  parameter=0,
-                                  temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors))
+        fwd_mem_cost = MemoryCost(
+            activation=compute_size_in_bytes(output_tensors) + compute_size_in_bytes(input_tensors[1]),
+            temp=compute_size_in_bytes(output_tensors),
+        )
+        bwd_mem_cost = MemoryCost(
+            activation=compute_size_in_bytes(input_tensors[0]),
+            parameter=0,
+            temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors),
+        )

     elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
         # Batched matrix-batched matrix multiplication
         # Fetch shape of the two inputs and see if the batch dimensions are the same
         _is_batch_dims_same = True
         if len(input_tensors[0].shape) == len(input_tensors[1].shape):
-            for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
+            for shape_0, shape_1 in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
                 if shape_0 != shape_1:
                     _is_batch_dims_same = False
                     break
@@ -337,20 +369,28 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
             # Case 1: batch dimensions are the same

            # Forward compute cost: C = A * B
-            fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([
-                input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape(
-                    -1, input_dim_10, input_dim_11)
-            ], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
+            fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
+                [
+                    input_tensors[0].reshape(-1, input_dim_00, input_dim_01),
+                    input_tensors[1].reshape(-1, input_dim_10, input_dim_11),
+                ],
+                [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
+            )

             # Backward compute cost: dB = A^T * dC, dA = dC * B^T
             bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
-                [input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
-                [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]
-            ) + \
-                flop_mapping[torch.ops.aten.bmm.default](
-                [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
-                [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
-            )
+                [
+                    input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00),
+                    output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
+                ],
+                [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)],
+            ) + flop_mapping[torch.ops.aten.bmm.default](
+                [
+                    output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
+                    input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10),
+                ],
+                [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)],
+            )

             fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
             bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors))
@@ -358,43 +398,46 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
         else:
             # Case 2: batch dimensions are different
             batch_dims = output_tensors[0].shape[:-2]
-            extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
-                                          input_dim_00,
-                                          input_dim_01,
-                                          device="meta")
-            extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
-                                          input_dim_10,
-                                          input_dim_11,
-                                          device="meta")
+            extended_input_0 = torch.rand(
+                reduce(lambda x, y: x * y, batch_dims), input_dim_00, input_dim_01, device="meta"
+            )
+            extended_input_1 = torch.rand(
+                reduce(lambda x, y: x * y, batch_dims), input_dim_10, input_dim_11, device="meta"
+            )

             # Forward compute cost: C = A * B
             fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
-                [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
+                [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]
+            )

             # Backward compute cost: dB = A^T * dC, dA = dC * B^T
             bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
-                [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
-                [extended_input_1]
-            ) + \
-                flop_mapping[torch.ops.aten.bmm.default](
-                [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
-                [extended_input_0]
-            )
+                [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
+                [extended_input_1],
+            ) + flop_mapping[torch.ops.aten.bmm.default](
+                [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
+                [extended_input_0],
+            )

             fwd_mem_cost = MemoryCost(
-                activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]))
-            bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) -
-                                      compute_size_in_bytes([extended_input_0, extended_input_1]),
-                                      temp=compute_size_in_bytes([extended_input_0, extended_input_1]))
+                activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])
+            )
+            bwd_mem_cost = MemoryCost(
+                activation=compute_size_in_bytes(input_tensors)
+                - compute_size_in_bytes([extended_input_0, extended_input_1]),
+                temp=compute_size_in_bytes([extended_input_0, extended_input_1]),
+            )

     # compute cost
     compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

     # memory cost
-    total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
-                            parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
-                            temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
-                            buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+    total_cost = MemoryCost(
+        activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+        parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+        temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+        buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
+    )

     memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
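For context on the cost model this purely cosmetic reformatting touches: the three backward terms summed in linear_meta_info correspond to the gradients dX = dY @ W (aten.mm), dW = dY^T @ X (aten.mm) and db = sum of dY over the batch (aten.sum.dim_IntList), while the forward term is the addmm Y = X @ W^T + b. The sketch below is only an illustration of that decomposition, not part of the commit; the function name and the 2-FLOPs-per-multiply-accumulate convention are assumptions and may differ from the counts in ColossalAI's flop_mapping.

# Illustrative sketch only (not from the repository): rough FLOP decomposition
# mirrored by linear_meta_info above. Assumes 2 FLOPs per multiply-accumulate.
def approx_linear_flops(batch: int, in_features: int, out_features: int) -> dict:
    fwd = 2 * batch * in_features * out_features      # Y = X @ W.T + b      (aten.addmm)
    bwd_dx = 2 * batch * in_features * out_features   # dX = dY @ W          (aten.mm)
    bwd_dw = 2 * batch * in_features * out_features   # dW = dY.T @ X        (aten.mm)
    bwd_db = batch * out_features                     # db = dY.sum(dim=0)   (aten.sum.dim_IntList)
    return {"fwd": fwd, "bwd": bwd_dx + bwd_dw + bwd_db}


print(approx_linear_flops(batch=32, in_features=1024, out_features=4096))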