mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-18 21:29:28 +00:00
[fx] Add linear metainfo class for auto parallel (#1783)
* [fx] metainfo class for auto parallel * [fx] add unit test for linear metainfo * [fx] fix bwd param for linear * [fx] modify unit test * [fx] modify unit test * [fx] modify import * [fx] modify import * [fx] modify import * [fx] move meta profiler to auto parallel
This commit is contained in:
3
colossalai/auto_parallel/meta_profiler/__init__.py
Normal file
3
colossalai/auto_parallel/meta_profiler/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .meta_registry import *
|
||||
from .metainfo import *
|
||||
from .registry import meta_register
|
@@ -0,0 +1 @@
|
||||
from .linear import *
|
157
colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
Normal file
157
colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
Normal file
@@ -0,0 +1,157 @@
|
||||
from typing import Callable, Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.fx.profiler.memory_utils import activation_size
|
||||
from colossalai.fx.profiler.opcount import flop_mapping
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from ..registry import meta_register
|
||||
|
||||
__all__ = ['linear_meta_info']
|
||||
|
||||
|
||||
@meta_register.register(torch.nn.Linear)
|
||||
def linear_meta_info(*args) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""torch.nn.Linear meta info generator
|
||||
The atens graph of torch.nn.Linear with bias is
|
||||
graph():
|
||||
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
|
||||
%addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (None, %input_2, None), kwargs = {})
|
||||
%zeros_like_default : [#users=3] = call_function[target=torch.ops.aten.zeros_like.default](args = (%addmm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
|
||||
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
|
||||
%mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {})
|
||||
%t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {})
|
||||
%mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {})
|
||||
%t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {})
|
||||
%sum_dim_int_list : [#users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%zeros_like_default, [None], None), kwargs = {})
|
||||
%view_default : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_dim_int_list, [None]), kwargs = {})
|
||||
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%view_default,), kwargs = {})
|
||||
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
|
||||
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default,), kwargs = {})
|
||||
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
|
||||
%t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {})
|
||||
%detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {})
|
||||
%detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {})
|
||||
|
||||
The one without bias is
|
||||
graph():
|
||||
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
|
||||
%mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%input_2, None), kwargs = {})
|
||||
%zeros_like_default : [#users=2] = call_function[target=torch.ops.aten.zeros_like.default](args = (%mm_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
|
||||
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
|
||||
%t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%zeros_like_default,), kwargs = {})
|
||||
%mm_default_1 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%t_default, %detach_default), kwargs = {})
|
||||
%t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%mm_default_1,), kwargs = {})
|
||||
%mm_default_2 : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%zeros_like_default, None), kwargs = {})
|
||||
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%mm_default_2,), kwargs = {})
|
||||
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
|
||||
%t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%t_default_1,), kwargs = {})
|
||||
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%t_default_2,), kwargs = {})
|
||||
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
|
||||
|
||||
Returns:
|
||||
Tuple[TrainCycleItem, TrainCycleItem, bool]: compute cost, memory cost and save input flag
|
||||
"""
|
||||
|
||||
has_bias: bool = False
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data
|
||||
|
||||
# process the dimension of input and output
|
||||
if len(input_tensor.shape) > 2:
|
||||
input_tensor: torch.Tensor
|
||||
input_tensor = input_tensor.view(-1, input_tensor.shape[-1])
|
||||
|
||||
if len(output_tensor.shape) > 2:
|
||||
output_tensor: torch.Tensor
|
||||
output_tensor = output_tensor.view(-1, output_tensor.shape[-1])
|
||||
|
||||
if len(args) == 4:
|
||||
bias_tensor = next(filter(lambda x: x.name == 'bias', args)).data
|
||||
has_bias = True
|
||||
|
||||
if has_bias:
|
||||
# calculate cost with bias
|
||||
# the fwd op with compute cost is addmm
|
||||
# the bwd op with compute cost is mm * 2 and sum.dim_IntList
|
||||
|
||||
# calculate compute cost
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
|
||||
[bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
|
||||
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \
|
||||
flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
|
||||
bwd=bwd_compute_cost,
|
||||
total=fwd_compute_cost + bwd_compute_cost)
|
||||
|
||||
# calculate memory cost
|
||||
# NOTE: Linear don't have buffer and temp in forward and backward phase
|
||||
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
|
||||
fwd_memory_cost = MemoryCost(activation=activation_size(output_tensor),
|
||||
parameter=activation_size(weight_tensor) + activation_size(bias_tensor),
|
||||
temp=0,
|
||||
buffer=0)
|
||||
|
||||
# the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
|
||||
bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) + activation_size(weight_tensor) +
|
||||
activation_size(bias_tensor),
|
||||
parameter=activation_size(weight_tensor) + activation_size(bias_tensor),
|
||||
temp=0,
|
||||
buffer=0)
|
||||
|
||||
# total cost is to sum the forward and backward cost
|
||||
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
|
||||
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
|
||||
|
||||
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
|
||||
|
||||
else:
|
||||
# calculate cost without bias
|
||||
# the fwd op with compute cost is mm
|
||||
# the bwd op with compute cost is mm * 2
|
||||
|
||||
# calculate compute cost
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
|
||||
[input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
|
||||
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))
|
||||
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
|
||||
bwd=bwd_compute_cost,
|
||||
total=fwd_compute_cost + bwd_compute_cost)
|
||||
|
||||
# calculate memory cost
|
||||
# NOTE: Linear don't have buffer and temp in forward and backward phase
|
||||
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
|
||||
fwd_memory_cost = MemoryCost(activation=activation_size(output_tensor),
|
||||
parameter=activation_size(weight_tensor),
|
||||
temp=0,
|
||||
buffer=0)
|
||||
|
||||
# the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
|
||||
bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) + activation_size(weight_tensor),
|
||||
parameter=activation_size(weight_tensor),
|
||||
temp=0,
|
||||
buffer=0)
|
||||
|
||||
# total cost is to sum the forward and backward cost
|
||||
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
|
||||
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
|
||||
|
||||
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
|
||||
|
||||
# store fwd_in
|
||||
fwd_in = [input_tensor]
|
||||
|
||||
return compute_cost, memory_cost, fwd_in
|
101
colossalai/auto_parallel/meta_profiler/metainfo.py
Normal file
101
colossalai/auto_parallel/meta_profiler/metainfo.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from .registry import meta_register
|
||||
|
||||
__all__ = ['MetaInfo']
|
||||
|
||||
|
||||
class MetaInfo:
|
||||
"""MetaInfo class
|
||||
This class is used to store meta info based on sharding strategy and the given
|
||||
target function.
|
||||
"""
|
||||
|
||||
def __init__(self, strategy: ShardingStrategy = None, target: Callable = None) -> None:
|
||||
# compute cost of forward and backward computation
|
||||
self.compute_cost: TrainCycleItem
|
||||
|
||||
# compute memory cost of forward and backward phase
|
||||
self.memory_cost: TrainCycleItem
|
||||
|
||||
# list of input tensors
|
||||
self.fwd_in: list[OperationData]
|
||||
|
||||
# sharding strategy
|
||||
self._strategy = strategy
|
||||
|
||||
# target function
|
||||
self._target = target
|
||||
|
||||
# compute metainfo if possible
|
||||
if self._strategy is not None and self._target is not None:
|
||||
self.compute_metainfo()
|
||||
|
||||
@property
|
||||
def strategy(self) -> ShardingStrategy:
|
||||
return self._strategy
|
||||
|
||||
@property
|
||||
def target(self) -> Callable:
|
||||
return self._target
|
||||
|
||||
@strategy.setter
|
||||
def strategy(self, strategy: ShardingStrategy) -> None:
|
||||
self._strategy = strategy
|
||||
if self._strategy is not None and self._target is not None:
|
||||
self.compute_metainfo()
|
||||
|
||||
@target.setter
|
||||
def target(self, target: Callable) -> None:
|
||||
self._target = target
|
||||
if self._strategy is not None and self._target is not None:
|
||||
self.compute_metainfo()
|
||||
|
||||
def compute_sharded_tensor(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
|
||||
"""
|
||||
Compute sharded meta tensor based on the given data and sharding spec.
|
||||
"""
|
||||
shard_sequnce = sharding_spec.sharding_sequence
|
||||
device_mesh = sharding_spec.device_mesh
|
||||
shape = operation_data.data.shape
|
||||
|
||||
new_shape = []
|
||||
for dim, shard in zip(shape, shard_sequnce):
|
||||
if shard.is_replica:
|
||||
# replica
|
||||
new_shape.append(dim)
|
||||
else:
|
||||
# sharded according to device_mesh shape
|
||||
new_shape.append(dim // np.prod(np.array([device_mesh.mesh_shape[i] for i in shard.shard_list])))
|
||||
|
||||
return OperationData(name=operation_data.name,
|
||||
data=torch.zeros(new_shape, device="meta"),
|
||||
type=operation_data.type,
|
||||
logical_shape=operation_data.logical_shape)
|
||||
|
||||
def compute_metainfo(self):
|
||||
"""
|
||||
Compute meta info based on sharding strategy and the given target function.
|
||||
"""
|
||||
|
||||
assert meta_register.has(self._target), f'{self._target} not found in the meta registry'
|
||||
meta_func = meta_register.get(self._target)
|
||||
|
||||
# construct args for meta_func
|
||||
args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
|
||||
|
||||
# compute metainfo with meta_func
|
||||
self.compute_cost, self.memory_cost, self.fwd_in = meta_func(*args)
|
32
colossalai/auto_parallel/meta_profiler/registry.py
Normal file
32
colossalai/auto_parallel/meta_profiler/registry.py
Normal file
@@ -0,0 +1,32 @@
|
||||
__all__ = ['Registry']
|
||||
|
||||
|
||||
class Registry:
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.store = {}
|
||||
|
||||
def register(self, source):
|
||||
|
||||
def wrapper(func):
|
||||
if isinstance(source, (list, tuple)):
|
||||
# support register a list of items for this func
|
||||
for element in source:
|
||||
self.store[element] = func
|
||||
else:
|
||||
self.store[source] = func
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
def get(self, source):
|
||||
assert source in self.store, f'{source} not found in the {self.name} registry'
|
||||
target = self.store[source]
|
||||
return target
|
||||
|
||||
def has(self, source):
|
||||
return source in self.store
|
||||
|
||||
|
||||
meta_register = Registry('meta')
|
Reference in New Issue
Block a user