mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] Patch meta information of torch.matmul (#2584)

* [autoparallel] matmul metainfo
* [auto_parallel] remove unused print
* [tests] skip test_matmul_handler when torch version is lower than 1.12.0
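The last bullet describes a version gate on the test; as a rough sketch (assuming pytest and the packaging library, and not the repository's exact test code), such a guard usually looks like this:

import pytest
import torch
from packaging import version

# Illustrative guard only; the real test body and helpers are omitted here.
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'),
                    reason='torch.matmul meta info requires torch >= 1.12.0')
def test_matmul_handler():
    ...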
@@ -1,3 +1,4 @@
+from functools import reduce
from typing import Callable, Dict, List, Tuple, Union

import torch
@@ -16,7 +17,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec

from ..registry import meta_register

-__all__ = ['linear_meta_info']
+__all__ = ['linear_meta_info', 'matmul_meta_info']


@meta_register.register(torch.nn.functional.linear)
@@ -170,3 +171,235 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
    fwd_out = [torch.zeros_like(output_tensor, device='meta')]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out


@meta_register.register(torch.matmul)
def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
    """torch.matmul meta info generator

    There are several cases for torch.matmul:
    1. Vector-vector multiplication => no temp memory; the forward memory cost is a single element (negligible) and
       the backward memory cost is the size of the two input vectors.
    2. Matrix-vector multiplication => if the first input is the matrix, no temp memory is needed; the forward memory
       cost is the size of the output tensor and the backward memory cost is the size of the two inputs. If the first
       input is the vector, the forward memory cost is still the size of the output tensor, but the backward phase
       allocates a temp buffer the same size as the input matrix (for its transpose) in addition to the memory for the
       gradients of the two inputs.
    3. Batched matrix-vector multiplication => if the first input is the batched matrix, no temp memory; the forward
       memory cost is the size of the output tensor and the backward memory cost is the size of the two inputs. If the
       second input is the batched matrix, matmul allocates memory for the gradient of the batched matrix already in
       the forward phase (a new tensor is created without the leading batch dimensions), so the forward memory cost is
       the output tensor plus the newly created matrix (the same amount of memory as the input batched matrix). The
       backward phase allocates a temp buffer the same size as the input batched matrix plus a tensor for the gradient
       of the input vector; the gradient of the batched matrix is stored in the memory allocated during the forward
       phase.
    4. Matrix-matrix multiplication => no temp memory; the forward memory cost is the size of the output tensor and
       the backward memory cost is the size of the two inputs.
    5. Batched matrix-matrix multiplication => if the first input is the batched matrix, no temp memory; the forward
       memory cost is the size of the output tensor and the backward memory cost is the size of the two inputs. If the
       second input is the batched matrix, the forward phase allocates memory for the output and for the gradient of
       the second input and uses a temp buffer the same size as the output; the backward phase allocates memory for
       the gradient of the first input and uses a temp buffer as big as the output plus the second input.
    6. Batched matrix-batched matrix multiplication => if the two inputs have the same batch dimensions, no temp
       memory; the forward memory cost is the size of the output and the backward memory cost is the size of the two
       inputs. If the two inputs have different batch dimensions, the forward phase allocates memory for the expanded
       inputs (so that the batch dimensions match) and for the output; the backward phase uses a temp buffer the size
       of the two expanded inputs, allocates memory for the gradients of the two inputs, and discards the expanded
       inputs allocated during the forward phase.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
    """
    # Get input and output tensors
    input_tensors = [args[0].data, args[1].data]
    output_tensors = [args[-1].data]

    # Check dimension
    if all(len(tensor.shape) == 1 for tensor in input_tensors):
        # Dot
        fwd_compute_cost = flop_mapping[torch.ops.aten.dot.default](input_tensors, output_tensors)
        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2

        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)

    elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1:
        # gemv case 1: matrix-vector multiplication
        # &
        # batched gemv case 1: batched matrix-vector multiplication

        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)

        # combine the dimensions of output
        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
            [output_tensors[0].reshape(-1), input_tensors[1]],
            output_tensors) + \
            flop_mapping[torch.ops.aten.mv.default](
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
            output_tensors)

        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)

    elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2:
        # gemv case 2: vector-matrix multiplication
        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](input_tensors, output_tensors)

        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
            flop_mapping[torch.ops.aten.mv.default]([input_tensors[1], output_tensors[0]], output_tensors)

        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors),
                                  parameter=0,
                                  temp=activation_size(input_tensors[1]),
                                  buffer=0)

    elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
        # batched gemv case 2: vector-batched matrix multiplication

        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
            [output_tensors[0].reshape(-1)])

        # combine the dimensions of output
        bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
            [output_tensors[0].reshape(-1), input_tensors[0]],
            output_tensors) + \
            flop_mapping[torch.ops.aten.mv.default](
            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
            output_tensors)

        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors + [input_tensors[1]]))
        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
                                  parameter=0,
                                  temp=activation_size(input_tensors[1]),
                                  buffer=0)

    elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
        # gemm & batched gemm case 1: batched matrix-matrix multiplication

        fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
            [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])

        bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1),
             output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
            [input_tensors[1]]) + \
            flop_mapping[torch.ops.aten.mm.default](
            [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
            [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])])

        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)

    elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
        # batched gemm case 2: matrix-batched matrix multiplication
        fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose(0, 1)],
            [output_tensors[0].transpose(-2, -1)])

        bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
            [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1),
             input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
            [input_tensors[0]]) + \
            flop_mapping[torch.ops.aten.mm.default](
            [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
            [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])])

        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors) + activation_size(input_tensors[1]),
                                  temp=activation_size(output_tensors))
        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
                                  parameter=0,
                                  temp=activation_size(input_tensors[1]) + activation_size(output_tensors))

    elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
        # Batched matrix-batched matrix multiplication
        # Fetch the shapes of the two inputs and check whether the batch dimensions are the same
        _is_batch_dims_same = True
        if len(input_tensors[0].shape) == len(input_tensors[1].shape):
            for shape_0, shape_1 in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
                if shape_0 != shape_1:
                    _is_batch_dims_same = False
                    break
        else:
            _is_batch_dims_same = False

        # retrieve dimensions
        input_dim_00 = input_tensors[0].shape[-2]
        input_dim_01 = input_tensors[0].shape[-1]
        input_dim_10 = input_tensors[1].shape[-2]
        input_dim_11 = input_tensors[1].shape[-1]
        output_dim_0 = output_tensors[0].shape[-2]
        output_dim_1 = output_tensors[0].shape[-1]

        if _is_batch_dims_same:
            # Case 1: batch dimensions are the same

            # Forward compute cost: C = A * B
            fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
                [input_tensors[0].reshape(-1, input_dim_00, input_dim_01),
                 input_tensors[1].reshape(-1, input_dim_10, input_dim_11)],
                [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])

            # Backward compute cost: dB = A^T * dC, dA = dC * B^T
            bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
                [input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00),
                 output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
                [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]) + \
                flop_mapping[torch.ops.aten.bmm.default](
                [output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
                 input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
                [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)])

            fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors))
            bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors))

        else:
            # Case 2: batch dimensions are different
            batch_dims = output_tensors[0].shape[:-2]
            extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
                                          input_dim_00,
                                          input_dim_01,
                                          device="meta")
            extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
                                          input_dim_10,
                                          input_dim_11,
                                          device="meta")

            # Forward compute cost: C = A * B
            fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
                [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])

            # Backward compute cost: dB = A^T * dC, dA = dC * B^T
            bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
                [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
                [extended_input_1]) + \
                flop_mapping[torch.ops.aten.bmm.default](
                [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
                [extended_input_0])

            fwd_mem_cost = MemoryCost(
                activation=activation_size([output_tensors[0], extended_input_0, extended_input_1]))
            bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors) -
                                      activation_size([extended_input_0, extended_input_1]),
                                      temp=activation_size([extended_input_0, extended_input_1]))

    # compute cost
    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

    # memory cost
    total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
                            parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
                            temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
                            buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)

    memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)

    # store fwd_in, fwd_buffer, fwd_out
    fwd_in = input_tensors
    fwd_buffer = []
    fwd_out = output_tensors

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
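The docstring's case analysis mirrors torch.matmul's shape rules. As a minimal sketch (not part of the patch, using plain CPU tensors and hypothetical shapes), each branch of matmul_meta_info can be exercised like this:

import torch

# Hypothetical shapes, chosen only to hit each branch.
v4 = torch.rand(4)
v5 = torch.rand(5)
mat = torch.rand(3, 4)
mat2 = torch.rand(4, 3)
bmat = torch.rand(8, 4, 5)

print(torch.matmul(v4, v4).shape)      # torch.Size([])          -> dot branch
print(torch.matmul(mat, v4).shape)     # torch.Size([3])         -> gemv case 1
print(torch.matmul(v4, mat2).shape)    # torch.Size([3])         -> gemv case 2
print(torch.matmul(bmat, v5).shape)    # torch.Size([8, 4])      -> batched gemv case 1
print(torch.matmul(v4, bmat).shape)    # torch.Size([8, 5])      -> batched gemv case 2
print(torch.matmul(mat, mat2).shape)   # torch.Size([3, 3])      -> gemm case
print(torch.matmul(mat, bmat).shape)   # torch.Size([8, 3, 5])   -> batched gemm case 2
a = torch.rand(2, 1, 3, 5)
b = torch.rand(2, 4, 5, 6)
print(torch.matmul(a, b).shape)        # torch.Size([2, 4, 3, 6]) -> broadcast batch dims

Each branch in matmul_meta_info flattens these shapes into an equivalent mv/mm/bmm call so that the existing flop_mapping entries can be reused for the compute cost.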
@@ -16,7 +16,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException

from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
-from .node_handler import NodeHandler
+from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import (
    BatchedMatMulStrategyGenerator,
@@ -326,7 +326,7 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms):

@operator_registry.register(torch.matmul)
@operator_registry.register(torch.Tensor.matmul)
-class MatMulHandler(NodeHandler):
+class MatMulHandler(MetaInfoNodeHandler):
    """
    The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
    According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
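Both decorators above follow the same registration pattern: the target callable (here torch.matmul) is used as a key so that a handler can later look up the strategy generator or meta-info function registered for it. A simplified stand-in for that pattern (not ColossalAI's actual Registry class, whose API may differ):

import torch
from typing import Callable, Dict

class SimpleRegistry:
    """Minimal illustration of the decorator-based registry pattern used above."""

    def __init__(self) -> None:
        self._store: Dict[Callable, Callable] = {}

    def register(self, target: Callable) -> Callable:
        def wrapper(func: Callable) -> Callable:
            self._store[target] = func   # key the entry by the original torch callable
            return func
        return wrapper

    def get(self, target: Callable) -> Callable:
        return self._store[target]

demo_register = SimpleRegistry()

@demo_register.register(torch.matmul)
def matmul_demo_info(*args, **kwargs):
    return "matmul meta info would be computed here"

print(demo_register.get(torch.matmul)())

Switching MatMulHandler's base class to MetaInfoNodeHandler is what hooks this lookup into the handler, so the sharding strategies it generates can presumably be annotated with the costs computed by matmul_meta_info above; targets without a registered meta-info function fall back to the warning shown in the next hunks.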
@@ -16,6 +16,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
)
from colossalai.auto_parallel.tensor_shard.utils import check_sharding_spec_validity
from colossalai.device.device_mesh import DeviceMesh
+from colossalai.logging import get_dist_logger
from colossalai.tensor.shape_consistency import ShapeConsistencyManager

from .strategy import StrategyGenerator
@@ -266,6 +267,10 @@ class MetaInfoNodeHandler(NodeHandler):
            # attach metainfos to the handler
            setattr(self, "metainfo_vector", metainfo_vector)

+        else:
+            logger = get_dist_logger()
+            logger.warning(f'The target function {target} is not patched yet, ')

        return self.strategies_vector


@@ -317,4 +322,8 @@ class MetaInfoModuleHandler(ModuleHandler):
            # attach metainfos to the handler
            setattr(self, "metainfo_vector", metainfo_vector)

+        else:
+            logger = get_dist_logger()
+            logger.warning(f'The target function {target} is not patched yet')

        return self.strategies_vector