Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-22 13:41:43 +00:00
[autoparallel]integrate auto parallel feature with new tracer (#3408)
* [autoparallel] integrate new analyzer in module level
* unify the profiling method
* polish
* fix no codegen bug
* fix pass bug
* fix liveness test
* polish
parent 573af84184
commit ffcdbf0f65
@@ -235,7 +235,28 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
     # Inputs contains the shapes of two matrices.
     input_shapes = [v.shape for v in inputs]
     assert len(input_shapes) == 2, input_shapes
-    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
+
+    # There are three cases: 1) gemm, 2) gemv, 3) dot
+    if all(len(shape) == 2 for shape in input_shapes):
+        # gemm
+        assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
+    elif all(len(shape) == 1 for shape in input_shapes):
+        # dot
+        assert input_shapes[0][0] == input_shapes[1][0], input_shapes
+
+        # expand shape
+        input_shapes[0] = torch.Size([1, input_shapes[0][0]])
+        input_shapes[1] = torch.Size([input_shapes[1][0], 1])
+    else:
+        # gemv
+        if len(input_shapes[0]) == 1:
+            assert input_shapes[0][0] == input_shapes[1][-2], input_shapes
+            input_shapes.reverse()
+        else:
+            assert input_shapes[1][0] == input_shapes[0][-1], input_shapes
+
+        # expand the shape of the vector to [batch size, 1]
+        input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1])
     flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
     return flops
 
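As a quick sanity check of the three branches introduced above, here is a minimal, hypothetical walk-through of the gemv path (the shapes are made up for illustration and are not part of the commit):

    import operator
    from functools import reduce

    import torch

    # hypothetical gemv: an (8, 16) matrix multiplied by a length-16 vector
    input_shapes = [torch.Size([8, 16]), torch.Size([16])]
    # the vector is viewed as a (16, 1) matrix, exactly as the gemv branch does
    input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1])
    flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]    # 8 * 16 * 1 = 128

The dot and gemm branches reduce to the same final product once the shapes have been normalised to two dimensions.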
@@ -1,8 +1,12 @@
 from typing import Any, Callable, Dict, Iterable, List, Tuple
 
 import torch
+
+try:
+    from torch.fx.graph import CodeGen
+except:
+    pass
 from torch.fx.graph import (
-    CodeGen,
     PythonCode,
     _custom_builtins,
     _format_target,
@@ -48,8 +52,8 @@ def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
     """
     Check if the node could end the ckpt region at `ckpt_level`
     """
-    if len(node.meta['info'].to_recompute) > ckpt_level:
-        return node.meta['info'].to_recompute[ckpt_level] is not None
+    if len(node.meta['info'].activation_checkpoint) > ckpt_level:
+        return node.meta['info'].activation_checkpoint[ckpt_level] is not None
     return True
 
 
@@ -90,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
     current_region = None
 
     for idx, node in enumerate(node_list):
-        if len(node.meta['info'].to_recompute) > ckpt_level:
-            act_ckpt_label = node.meta['info'].to_recompute[ckpt_level]
+        if len(node.meta['info'].activation_checkpoint) > ckpt_level:
+            act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
 
             # this activation checkpoint label is not set yet
             # meaning this is the first node of the activation ckpt region
@@ -152,12 +156,12 @@ def emit_ckpt_func(body,
 
     # label given by each layer, e.g. if you are currently at level (0, 1, 1)
     # the label will be '0_1_1'
-    label = "_".join([str(idx) for idx in node_list[0].meta['info'].to_recompute[:ckpt_level + 1]])
+    label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
     ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
     ckpt_func.append(f'{ckpt_fn_def}\n')
 
     # if there is more level to fetch
-    if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].to_recompute), node_list)):
+    if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
         ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
         start_idx = [item[0] for item in ckpt_regions]
         end_idx = [item[1] for item in ckpt_regions]
@@ -215,7 +219,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
     ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
     start_idx = [item[0] for item in ckpt_regions]
     end_idx = [item[1] for item in ckpt_regions]
 
     node_list = list(nodes)
-
     node_idx = 0
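The `to_recompute` field renamed to `activation_checkpoint` throughout these hunks is used as a tuple of per-level region labels, and the label logic above is just a join over a prefix of that tuple. A small sketch with a hypothetical annotation:

    # hypothetical node annotated as living in nested checkpoint regions (0, 1, 1)
    activation_checkpoint = (0, 1, 1)

    ckpt_level = 2
    label = "_".join(str(idx) for idx in activation_checkpoint[:ckpt_level + 1])
    assert label == "0_1_1"

    # mirroring _end_of_ckpt: a node whose annotation is shallower than the
    # queried level can close the region at that level
    queried_level = 3
    could_end = not (len(activation_checkpoint) > queried_level)
    assert could_end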
@@ -112,7 +112,7 @@ class MetaInfo:
 
     # should keep the same whenever manipulated
     # ============================= Invariant ==================================
-    to_recompute: Tuple[torch.Tensor] = ()    # (region_0, region_1, ...) support nested codegen
+    activation_checkpoint: Tuple[torch.Tensor] = ()    # (region_0, region_1, ...) support nested codegen
     to_offload: Optional[bool] = False
     sharding_spec: str = 'RR'
 
@@ -237,7 +237,14 @@ class ShapeProp(torch.fx.Interpreter):
         Returns:
             Any: The value returned from executing the Module
         """
-        wrap_fn = lambda elem: MetaTensor(elem, device=device)
+
+        # wrap_fn = lambda elem: MetaTensor(elem, device=device)
+        def wrap_fn(elem, device=device):
+            if isinstance(elem, torch.Tensor):
+                return MetaTensor(elem, device=device)
+            else:
+                return elem
+
         with self._mode:
             return super().run(*tree_map(wrap_fn, args))
 
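The lambda is replaced because `tree_map` visits every leaf of `args`, including non-tensor values such as ints or strings, which the old lambda would try to wrap. A hedged, stand-alone sketch of the same guard, using `Tensor.to('meta')` in place of the library's `MetaTensor`:

    import torch
    from torch.utils._pytree import tree_map

    def wrap_fn(elem):
        # only tensors get wrapped; every other leaf passes through untouched
        if isinstance(elem, torch.Tensor):
            return elem.to('meta')
        return elem

    args = (torch.randn(2, 3), 42, "not a tensor")
    wrapped = tree_map(wrap_fn, args)
    assert wrapped[0].is_meta and wrapped[1] == 42 and wrapped[2] == "not a tensor"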
@@ -21,69 +21,111 @@ def linear_impl(input, weight, bias=None):
 
 
 @register_tracer_impl(F.conv1d, name='_bias_addition_impl')
-def conv1d_impl(input, weight, **kwargs):
-    bias = getattr(kwargs, 'bias', None)
+def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
     if bias is None:
-        return F.conv1d(input, weight, **kwargs)
+        return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
     else:
-        new_kwargs = kwargs
-        new_kwargs['bias'] = None
-        return F.conv1d(input, weight, **kwargs) + bias.reshape((-1, 1))
+        return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
+            (-1, 1))
 
 
 @register_tracer_impl(F.conv2d, name='_bias_addition_impl')
-def conv2d_impl(input, weight, **kwargs):
-    bias = getattr(kwargs, 'bias', None)
+def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
     if bias is None:
-        return F.conv2d(input, weight, **kwargs)
+        return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
     else:
-        new_kwargs = kwargs
-        new_kwargs['bias'] = None
-        return F.conv2d(input, weight, **kwargs) + bias.reshape((-1, 1, 1))
+        return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
+            (-1, 1, 1))
 
 
 @register_tracer_impl(F.conv3d, name='_bias_addition_impl')
-def conv3d_impl(input, weight, **kwargs):
-    bias = getattr(kwargs, 'bias', None)
+def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
     if bias is None:
-        return F.conv3d(input, weight, **kwargs)
+        return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
     else:
-        new_kwargs = kwargs
-        new_kwargs['bias'] = None
-        return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
+        return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
+            (-1, 1, 1, 1))
 
 
 @register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
-def conv_transpose1d_impl(input, weight, **kwargs):
-    bias = getattr(kwargs, 'bias', None)
+def conv_transpose1d_impl(input,
+                          weight,
+                          bias=None,
+                          stride=_single(1),
+                          padding=_single(0),
+                          output_padding=_single(0),
+                          groups=1,
+                          dilation=_single(1)):
     if bias is None:
-        return F.conv_transpose1d(input, weight, **kwargs)
+        return F.conv_transpose1d(input,
+                                  weight,
+                                  stride=stride,
+                                  padding=padding,
+                                  output_padding=output_padding,
+                                  groups=groups,
+                                  dilation=dilation)
     else:
-        new_kwargs = kwargs
-        new_kwargs['bias'] = None
-        return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))
+        return F.conv_transpose1d(input,
+                                  weight,
+                                  stride=stride,
+                                  padding=padding,
+                                  output_padding=output_padding,
+                                  groups=groups,
+                                  dilation=dilation) + bias.reshape((-1, 1))
 
 
 @register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
-def conv_transpose2d_impl(input, weight, **kwargs):
-    bias = getattr(kwargs, 'bias', None)
+def conv_transpose2d_impl(input,
+                          weight,
+                          bias=None,
+                          stride=_pair(1),
+                          padding=_pair(0),
+                          output_padding=_pair(0),
+                          groups=1,
+                          dilation=_pair(1)):
     if bias is None:
-        return F.conv_transpose2d(input, weight, **kwargs)
+        return F.conv_transpose2d(input,
+                                  weight,
+                                  stride=stride,
+                                  padding=padding,
+                                  output_padding=output_padding,
+                                  groups=groups,
+                                  dilation=dilation)
     else:
-        new_kwargs = kwargs
-        new_kwargs['bias'] = None
-        return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))
+        return F.conv_transpose2d(input,
+                                  weight,
+                                  stride=stride,
+                                  padding=padding,
+                                  output_padding=output_padding,
+                                  groups=groups,
+                                  dilation=dilation) + bias.reshape((-1, 1, 1))
 
 
 @register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
-def conv_transpose3d_impl(input, weight, **kwargs):
-    bias = getattr(kwargs, 'bias', None)
+def conv_transpose3d_impl(input,
+                          weight,
+                          bias=None,
+                          stride=_triple(1),
+                          padding=_triple(0),
+                          output_padding=_triple(0),
+                          groups=1,
+                          dilation=_triple(1)):
     if bias is None:
-        return F.conv_transpose3d(input, weight, **kwargs)
+        return F.conv_transpose3d(input,
+                                  weight,
+                                  stride=stride,
+                                  padding=padding,
+                                  output_padding=output_padding,
+                                  groups=groups,
+                                  dilation=dilation)
    else:
-        new_kwargs = kwargs
-        new_kwargs['bias'] = None
-        return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
+        return F.conv_transpose3d(input,
+                                  weight,
+                                  stride=stride,
+                                  padding=padding,
+                                  output_padding=output_padding,
+                                  groups=groups,
+                                  dilation=dilation) + bias.reshape((-1, 1, 1, 1))
 
 
 @register_tracer_impl(torch.addmm, name='_bias_addition_impl')
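All of these implementations rewrite a fused call into a bias-free convolution plus an explicit broadcasted add, so the tracer sees the bias addition as its own node. A quick numerical check of that identity with hypothetical shapes:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 8, 8)      # (N, C_in, H, W)
    w = torch.randn(4, 3, 3, 3)      # (C_out, C_in, kH, kW)
    b = torch.randn(4)               # (C_out,)

    fused = F.conv2d(x, w, bias=b)
    decomposed = F.conv2d(x, w) + b.reshape((-1, 1, 1))    # bias broadcast over (N, C_out, H', W')
    assert torch.allclose(fused, decomposed, atol=1e-6)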
@@ -155,7 +155,7 @@ class ColoTracer(Tracer):
 
     def create_node(self, *args, **kwargs) -> Node:
        node = super().create_node(*args, **kwargs)
-        n_info = MetaInfo(node, mod_dir=self.mod_dir, to_recompute=tuple(self.ckpt_regions))
+        n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
         return node
 
     def trace(self,
@@ -1,3 +1,3 @@
 from .meta_registry import *
-from .metainfo import *
 from .registry import meta_register
+from .shard_metainfo import *
@@ -2,9 +2,9 @@ from typing import Callable, List, Tuple
 
 import torch
 
+from colossalai._analyzer._subclasses.flop_tensor import ewise_flop_counter as elementwise_flop_counter
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import elementwise_flop_counter
 
 from ..registry import meta_register
 
@@ -2,9 +2,9 @@ from typing import List, Tuple
 
 import torch
 
+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 
 from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
 from ..registry import meta_register
@@ -17,7 +17,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
     """Meta information generator for binary elementwise operations
     NOTE: Some of the binary elementwise operations will discard the input activation after computation, as they
     don't need those tensors for back propagation, for example, if there are two tensors being sent for `torch.add`,
-    they will be discarded right after add operation is done. We create a simple API in `MetaInfo` class to identify
+    they will be discarded right after add operation is done. We create a simple API in `ShardMetaInfo` class to identify
     this behavior, it is critical for better memory estimation.
 
     Returns:
@@ -2,6 +2,8 @@ from typing import Callable, Dict, List, Tuple, Union
 
 import torch
 
+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     MemoryCost,
     OperationData,
@@ -10,8 +12,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     StrategiesVector,
     TrainCycleItem,
 )
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 from colossalai.tensor.sharding_spec import ShardingSpec
 
 from ..registry import meta_register
@@ -110,18 +110,18 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     # calculate memory cost
     # TODO: use profiler to check conv temp memory
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(
-        activation=activation_size([input_tensor, output_tensor]),
-        parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
+                                 if has_bias else compute_size_in_bytes(weight_tensor),
                                  temp=0,
                                  buffer=0)
 
-    bwd_memory_cost = MemoryCost(
-        activation=activation_size([input_tensor, weight_tensor, bias_tensor])
-        if has_bias else activation_size([input_tensor, weight_tensor]),
-        parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
+                                 if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
+                                 if has_bias else compute_size_in_bytes(weight_tensor),
                                  temp=0,
                                  buffer=0)
 
     # total cost is the sum of forward and backward cost
     total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
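`activation_size` and its replacement `compute_size_in_bytes` both express the byte footprint of a tensor or a list of tensors; a hedged stand-in that captures the idea (the real helper lives in `colossalai._analyzer.fx.node_util` and may handle more cases):

    import torch

    def size_in_bytes(tensors):
        # numel * element size, summed over one tensor or a list of tensors
        if isinstance(tensors, torch.Tensor):
            tensors = [tensors]
        return sum(t.numel() * t.element_size() for t in tensors)

    input_tensor = torch.empty(2, 3, 8, 8)     # hypothetical fp32 activations
    output_tensor = torch.empty(2, 4, 6, 6)
    fwd_activation_bytes = size_in_bytes([input_tensor, output_tensor])    # (384 + 288) * 4 bytes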
@@ -2,9 +2,9 @@ from typing import List, Tuple
 
 import torch
 
+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 
 from ..registry import meta_register
 
@@ -34,11 +34,11 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
     # NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will
     # have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume
     # that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory
-    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
                                  parameter=0,
                                  temp=0,
                                  buffer=0)
-    bwd_memory_cost = MemoryCost(activation=activation_size([weight_tensor]), parameter=0, temp=0, buffer=0)
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0)
 
     total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)
 
@@ -3,6 +3,8 @@ from typing import Callable, Dict, List, Tuple, Union
 
 import torch
 
+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     MemoryCost,
     OperationData,
@@ -11,8 +13,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     StrategiesVector,
     TrainCycleItem,
 )
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 from colossalai.tensor.sharding_spec import ShardingSpec
 
 from ..registry import meta_register
@@ -112,14 +112,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     # NOTE: Linear don't have buffer and temp in forward and backward phase
     # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
-                                 parameter=activation_size([weight_tensor, bias_tensor]),
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                  temp=0,
                                  buffer=0)
 
     # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
-    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]),
-                                 parameter=activation_size([weight_tensor, bias_tensor]),
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                  temp=0,
                                  buffer=0)
 
@@ -148,14 +148,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     # NOTE: Linear don't have buffer and temp in forward and backward phase
     # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
-                                 parameter=activation_size(weight_tensor),
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
+                                 parameter=compute_size_in_bytes(weight_tensor),
                                  temp=0,
                                  buffer=0)
 
     # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
-    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor]),
-                                 parameter=activation_size(weight_tensor),
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]),
+                                 parameter=compute_size_in_bytes(weight_tensor),
                                  temp=0,
                                  buffer=0)
 
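In every one of these generators the total memory cost is the field-wise sum of the forward and backward items; a minimal sketch with made-up numbers (a local dataclass stands in for `MemoryCost`):

    from dataclasses import dataclass

    @dataclass
    class Cost:
        activation: int = 0
        parameter: int = 0

    fwd = Cost(activation=1024, parameter=512)
    bwd = Cost(activation=2048, parameter=512)
    total = Cost(activation=fwd.activation + bwd.activation,
                 parameter=fwd.parameter + bwd.parameter)
    assert total == Cost(activation=3072, parameter=1024)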
@@ -210,48 +210,48 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     # Check dimension
     if all(len(tensor.shape) == 1 for tensor in input_tensors):
         # Dot
-        fwd_compute_cost = flop_mapping[torch.ops.aten.dot.default](input_tensors, output_tensors)
+        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
         bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2
 
-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
 
     elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1:
         # gemv case 1: matrix-vector multiplication
         # &
         # batched gemv case 1: batched matrix-vector multiplication
 
-        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
+        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
             [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
 
         # combine the dimensions of output
         bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
             [output_tensors[0].reshape(-1), input_tensors[1]],
             output_tensors) + \
-            flop_mapping[torch.ops.aten.mv.default](
+            flop_mapping[torch.ops.aten.matmul.default](
             [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
             output_tensors)
 
-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
 
     elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2:
         # gemv case 2: vector-matrix multiplication
-        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](input_tensors, output_tensors)
+        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
 
         bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
-            flop_mapping[torch.ops.aten.mv.default]([input_tensors[1], output_tensors[0]], output_tensors)
+            flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
 
-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors),
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors),
                                   parameter=0,
-                                  temp=activation_size(input_tensors[1]),
+                                  temp=compute_size_in_bytes(input_tensors[1]),
                                   buffer=0)
 
     elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
         # batched gemv case 2: vector-batched matrix multiplication
 
-        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
+        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
             [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
             [output_tensors[0].reshape(-1)])
 
@@ -260,15 +260,15 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
             [output_tensors[0].reshape(-1), input_tensors[0]],
             output_tensors
         ) + \
-            flop_mapping[torch.ops.aten.mv.default](
+            flop_mapping[torch.ops.aten.matmul.default](
             [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
             output_tensors
         )
 
-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors + [input_tensors[1]]))
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
                                   parameter=0,
-                                  temp=activation_size(input_tensors[1]),
+                                  temp=compute_size_in_bytes(input_tensors[1]),
                                   buffer=0)
 
     elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
@@ -287,8 +287,8 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
             [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
         )
 
-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
 
     elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
         # batched gemm case 2: matrix-batched matrix multiplication
@@ -306,11 +306,12 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
             [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
         )
 
-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors) + activation_size(input_tensors[1]),
-                                  temp=activation_size(output_tensors))
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) +
+                                  compute_size_in_bytes(input_tensors[1]),
+                                  temp=compute_size_in_bytes(output_tensors))
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
                                   parameter=0,
-                                  temp=activation_size(input_tensors[1]) + activation_size(output_tensors))
+                                  temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors))
 
     elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
         # Batched matrix-batched matrix multiplication
@@ -351,8 +352,8 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
                 [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
             )
 
-            fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors))
-            bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors))
+            fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
+            bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors))
 
         else:
             # Case 2: batch dimensions are different
@@ -381,10 +382,10 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
             )
 
             fwd_mem_cost = MemoryCost(
-                activation=activation_size([output_tensors[0], extended_input_0, extended_input_1]))
-            bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors) -
-                                      activation_size([extended_input_0, extended_input_1]),
-                                      temp=activation_size([extended_input_0, extended_input_1]))
+                activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]))
+            bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) -
+                                      compute_size_in_bytes([extended_input_0, extended_input_1]),
+                                      temp=compute_size_in_bytes([extended_input_0, extended_input_1]))
 
     # compute cost
     compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
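The only substantive change in the hunks above is the key used to index the FLOP table: every `torch.ops.aten.mv.default` / `torch.ops.aten.dot.default` lookup becomes `torch.ops.aten.matmul.default`. A hedged sketch of that dispatch pattern with a simplified counter (not the library's actual table):

    import torch

    def naive_matmul_flops(inputs, outputs):
        # hypothetical counter: rows * inner dim * cols of two 2D operands
        a, b = inputs
        return a.shape[0] * a.shape[1] * b.shape[1]

    flop_table = {torch.ops.aten.matmul.default: naive_matmul_flops}

    a, b, out = torch.empty(8, 16), torch.empty(16, 4), torch.empty(8, 4)
    flops = flop_table[torch.ops.aten.matmul.default]([a, b], [out])    # 8 * 16 * 4 = 512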
@@ -4,8 +4,6 @@ from typing import List, Tuple
 import torch
 
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 
 from ..registry import meta_register
 
@@ -2,6 +2,8 @@ from typing import Callable, Dict, List, Tuple, Union
 
 import torch
 
+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     MemoryCost,
     OperationData,
@@ -10,8 +12,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     StrategiesVector,
     TrainCycleItem,
 )
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 from colossalai.tensor.sharding_spec import ShardingSpec
 
 from ..registry import meta_register
@@ -77,17 +77,18 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
     # calculate memory cost
     # the fwd activation cost is output plus saved mean and saved inv std
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, mean_tensor, var_tensor]),
-                                 parameter=activation_size([weight_tensor, bias_tensor]),
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
+        [input_tensor, output_tensor, mean_tensor, var_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                  temp=0,
-                                 buffer=activation_size([mean_tensor, var_tensor]))
+                                 buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
 
     # the bwd memory cost is quite tricky here, BatchNorm will remove saved mean
     # and saved inv std during backward phase
-    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor]),
-                                 parameter=activation_size([weight_tensor, bias_tensor]),
-                                 temp=activation_size([mean_tensor, var_tensor]),
-                                 buffer=activation_size([mean_tensor, var_tensor]))
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+                                 temp=compute_size_in_bytes([mean_tensor, var_tensor]),
+                                 buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
 
     # total cost is the sum of forward and backward cost
     total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
@@ -131,15 +132,16 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
 
     # memory cost
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, weight_tensor, bias_tensor]),
-                                 parameter=activation_size([weight_tensor, bias_tensor]),
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
+        [input_tensor, output_tensor, weight_tensor, bias_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                  temp=0,
-                                 buffer=activation_size([running_mean, running_var]))
+                                 buffer=compute_size_in_bytes([running_mean, running_var]))
 
-    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]),
-                                 parameter=activation_size([weight_tensor, bias_tensor]),
-                                 temp=activation_size([running_mean, running_var]),
-                                 buffer=activation_size([running_mean, running_var]))
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+                                 temp=compute_size_in_bytes([running_mean, running_var]),
+                                 buffer=compute_size_in_bytes([running_mean, running_var]))
 
     total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
                             parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
@@ -2,9 +2,9 @@ from typing import List, Tuple
 
 import torch
 
+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 
 from ..registry import meta_register
 
@@ -52,8 +52,8 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
     compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
 
     # calculate memory cost
-    fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(output_tensor))
-    bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(input_tensor))
+    fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(output_tensor))
+    bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(input_tensor))
 
     # total cost
     total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation)
@@ -114,11 +114,11 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
     # calculate memory cost
     # NOTE: the index matrix will be discarded in backward phase
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_mem_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, index_matrix]))
+    fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix]))
 
     # temp memory for backward is the index matrix to be discarded
-    bwd_mem_cost = MemoryCost(activation=activation_size(input_tensor) - activation_size(index_matrix),
-                              temp=activation_size(index_matrix))
+    bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
+                              temp=compute_size_in_bytes(index_matrix))
 
     # total cost
     total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)
@@ -2,9 +2,9 @@ from typing import Callable, List, Tuple
 
 import torch
 
+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 
 from ..registry import meta_register
 
@@ -35,11 +35,11 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
 
     # memory costs
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_mem_cost = MemoryCost(activation=activation_size(outputs) * 2, parameter=0, temp=0, buffer=0)
+    fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0)
 
-    bwd_mem_cost = MemoryCost(activation=activation_size(outputs) * bwd_mem_out_factor,
+    bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
                               parameter=0,
-                              temp=activation_size(outputs) * bwd_mem_tmp_factor,
+                              temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
                               buffer=0)
 
     total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
@@ -2,9 +2,9 @@ from typing import List, Tuple
 
 import torch
 
+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 
 from ..registry import meta_register
 
@@ -15,11 +15,11 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
 from .registry import meta_register
 
-__all__ = ['MetaInfo']
+__all__ = ['ShardMetaInfo']
 
 
-class MetaInfo:
-    """MetaInfo class
+class ShardMetaInfo:
+    """ShardMetaInfo class
     This class is used to store meta info based on sharding strategy and the given
     target function.
     """
@@ -46,9 +46,9 @@ class MetaInfo:
         # target function
         self._target = target
 
-        # compute metainfo if possible
+        # compute shard_metainfo if possible
         if self._strategy is not None and self._target is not None:
-            self.compute_metainfo()
+            self.compute_shard_metainfo()
 
     @property
     def strategy(self) -> ShardingStrategy:
@@ -62,13 +62,13 @@ class MetaInfo:
     def strategy(self, strategy: ShardingStrategy) -> None:
         self._strategy = strategy
         if self._strategy is not None and self._target is not None:
-            self.compute_metainfo()
+            self.compute_shard_metainfo()
 
     @target.setter
     def target(self, target: Callable) -> None:
         self._target = target
         if self._strategy is not None and self._target is not None:
-            self.compute_metainfo()
+            self.compute_shard_metainfo()
 
     def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec):
         """
@@ -93,7 +93,7 @@ class MetaInfo:
 
         return op_data
 
-    def compute_metainfo(self):
+    def compute_shard_metainfo(self):
         """
         Compute meta info based on sharding strategy and the given target function.
         """
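After the rename, call sites build a `ShardMetaInfo` exactly as they previously built a `MetaInfo`; a hedged usage sketch based only on what appears in this diff:

    from colossalai.auto_parallel.meta_profiler import ShardMetaInfo
    from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem

    # empty container first, costs filled in afterwards, as the passes below do
    meta_info = ShardMetaInfo()
    meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost())

Setting `strategy` and `target` instead triggers `compute_shard_metainfo()` through the setters shown above.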
@ -4,7 +4,7 @@ import torch
|
|||||||
from torch.fx import GraphModule
|
from torch.fx import GraphModule
|
||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
|
|
||||||
from colossalai.auto_parallel.meta_profiler import MetaInfo
|
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo
|
||||||
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
|
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
|
||||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
|
||||||
from colossalai.tensor.comm_spec import CommSpec
|
from colossalai.tensor.comm_spec import CommSpec

@ -14,15 +14,15 @@ from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()


-def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
-                         target_sharding_spec: ShardingSpec) -> MetaInfo:
+def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
+                               target_sharding_spec: ShardingSpec) -> ShardMetaInfo:
    # get comm_action_sequence and total_cost from shape_consistency_manager
    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
        origin_sharding_spec, target_sharding_spec)

-    meta_info = MetaInfo()
+    meta_info = ShardMetaInfo()
    # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
-    # get mem cost for MetaInfo
+    # get mem cost for ShardMetaInfo
    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
    # extract user that has _meta_data and extract element length
    input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
@ -36,12 +36,12 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,

    meta_info.memory_cost = mem_cost

-    # get computation cost for MetaInfo
+    # get computation cost for ShardMetaInfo
    meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
                                            total_cost['backward'] * element_length,
                                            total_cost['total'] * element_length)

-    # get tensor shape for MetaInfo
+    # get tensor shape for ShardMetaInfo
    origin_sharding_spec: ShardingSpec
    target_sharding_spec: ShardingSpec
    input_shape = origin_sharding_spec.get_sharded_shape_per_device()
@ -54,7 +54,7 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
    return meta_info


-def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo:
+def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> ShardMetaInfo:
    """
    This method is used to construct `MetaInto` for shape consistency node
    """
@ -65,17 +65,17 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -
    origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
        user_node_index]

-    return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
+    return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)


-def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo:
+def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> ShardMetaInfo:
    # extract node_index and op_data_name
    node_index, op_data_name = node.args[2], node.args[3]

    comm_action = comm_actions_dict[node_index][op_data_name]
    if isinstance(comm_action.comm_spec, CommSpec):
        # this case is for all_reduce, there will be no memory cost
-        meta_info = MetaInfo()
+        meta_info = ShardMetaInfo()
        meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)
        output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
        element_length = output_node._meta_data.element_size()
@ -93,7 +93,7 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M
        # this case will be handled by shape consistency manager
        origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
            'tgt_spec']
-        meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
+        meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)

    return meta_info

@ -105,9 +105,9 @@ def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_di
    """
    for node in gm.graph.nodes:
        if node.target == runtime_apply:
-            setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
+            setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
        elif node.target == runtime_comm_spec_apply:
-            setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
+            setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
        else:
            pass
    return gm
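
Note for readers of this hunk: `shape_consistency_manager.mem_cost` and the returned `total_cost` are counted in numbers of elements, so the pass scales them by the element size of a neighbouring tensor to obtain bytes. A minimal sketch of that conversion using only plain PyTorch (the `ShardMetaInfo`/`TrainCycleItem` wiring above is assumed):

    import torch

    def numel_cost_to_bytes(total_cost: dict, reference: torch.Tensor) -> dict:
        # Costs from the shape-consistency manager are expressed in element counts;
        # element_size() gives bytes per element for the reference dtype.
        element_length = reference.element_size()
        return {phase: count * element_length for phase, count in total_cost.items()}

    # Example: a cost of 1024 elements on an fp16 tensor corresponds to 2048 bytes.
    ref = torch.empty(4, 8, dtype=torch.float16)
    print(numel_cost_to_bytes({'forward': 1024, 'backward': 2048, 'total': 3072}, ref))
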
@ -7,7 +7,7 @@ import torch.fx
from torch.fx import GraphModule
from torch.fx.node import Node

-from colossalai.auto_parallel.meta_profiler import MetaInfo
+from colossalai.auto_parallel.meta_profiler import ShardMetaInfo
from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo
@ -96,12 +96,12 @@ class MetaInfoProp:
        """
        Handle other kind of nodes
        """
-        assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
+        assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
        graph_info = GraphInfo()
-        meta_info = node.best_metainfo
-        meta_info: MetaInfo
+        meta_info = node.best_strategy_info
+        meta_info: ShardMetaInfo

-        # set data_ptr for input_tensor in MetaInfo class
+        # set data_ptr for input_tensor in ShardMetaInfo class
        input_tensors: List[torch.Tensor] = meta_info.fwd_in
        buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
        output_tensors: List[torch.Tensor] = meta_info.fwd_out
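
For context, the renamed attribute is what this propagation pass consumes: after the solver has chosen a strategy, each handled node carries a `ShardMetaInfo` under `best_strategy_info`, and its `fwd_in` / `fwd_buffer` / `fwd_out` lists describe the tensors touched in the forward pass. A hedged sketch of reading those fields (attribute names are taken from the hunk above; `gm` is assumed to be a traced and annotated GraphModule):

    def forward_output_bytes(gm) -> int:
        total = 0
        for node in gm.graph.nodes:
            meta_info = getattr(node, 'best_strategy_info', None)
            if meta_info is None:
                continue
            # fwd_out is a list of torch.Tensor descriptions, as typed in the hunk above.
            total += sum(t.numel() * t.element_size() for t in meta_info.fwd_out)
        return total
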
@ -4,7 +4,7 @@ from typing import Dict, List
import torch
from torch.fx.node import Node

-from colossalai.auto_parallel.meta_profiler import MetaInfo
+from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    CommAction,
    CommType,
@ -128,9 +128,10 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                    runtime_apply,
                    args=(node, origin_dict_node, input_dict_node,
                          node_to_index_dict[node], user_node_index))
-                if 'activation_checkpoint' in user_node.meta:
-                    shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']
+                if hasattr(user_node.meta['info'], 'activation_checkpoint'):
+                    MetaInfo(shape_consistency_node,
+                             mod_dir=user_node.meta['info'].mod_dir,
+                             activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))
                new_args = list(user_node.args)
                new_kwargs = dict(user_node.kwargs)
                # the origin node may be a positional argument or key word argument of user node
@ -210,9 +211,10 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
            # substitute the origin node with comm_spec_apply_node
            new_kwargs[str(node)] = comm_spec_apply_node
            user.kwargs = new_kwargs
-            if 'activation_checkpoint' in node.meta:
-                comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
+            if hasattr(node.meta['info'], 'activation_checkpoint'):
+                MetaInfo(comm_spec_apply_node,
+                         mod_dir=node.meta['info'].mod_dir,
+                         activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))

    return gm

@ -6,6 +6,7 @@ import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node

+from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    CommAction,
@ -74,9 +75,9 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
            origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
                str(node))

-        # attach the corresponding metainfo if node has the attribute `metainfo_vector`
-        if hasattr(node, 'metainfo_vector'):
-            setattr(node, 'best_metainfo', node.metainfo_vector[strategy_index])
+        # attach the corresponding metainfo if node has the attribute `strategies_info`
+        if hasattr(node, 'strategies_info'):
+            setattr(node, 'best_strategy_info', node.strategies_info[strategy_index])

    # the dict to get input sharding specs of user node
    sharding_spec_convert_dict = {}
@ -172,8 +173,11 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
            # It will be used to replace the original node with processing node in slice object
            node_pairs[node] = size_processing_node
            size_processing_node._meta_data = node._meta_data
-            if 'activation_checkpoint' in node.meta:
-                size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
+            if hasattr(node.meta['info'], 'activation_checkpoint'):
+                MetaInfo(size_processing_node,
+                         mod_dir=node.meta['info'].mod_dir,
+                         activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))

            user_list = list(node.users.keys())
            for user in user_list:
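
Both passes in this file now follow the same pattern when they splice a new node into the graph: instead of copying a raw `activation_checkpoint` entry between `node.meta` dicts, they construct the analyzer's `MetaInfo` for the new node and inherit `mod_dir` and `activation_checkpoint` from the node that triggered the insertion. A small helper capturing that pattern, sketched from the calls shown in the diff (the constructor is assumed to register the info on the node, which is how the passes use it):

    from torch.fx import Node
    from colossalai._analyzer.fx.node_util import MetaInfo

    def inherit_analyzer_info(new_node: Node, src_node: Node) -> None:
        """Copy analyzer metadata from src_node onto a freshly inserted node."""
        src_info = src_node.meta['info']
        if hasattr(src_info, 'activation_checkpoint'):
            # Constructing MetaInfo(...) attaches the record to new_node, mirroring
            # the usage in _shape_consistency_apply and size_value_converting_pass.
            MetaInfo(new_node,
                     mod_dir=src_info.mod_dir,
                     activation_checkpoint=tuple(src_info.activation_checkpoint))
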
@ -6,6 +6,10 @@ import torch.nn as nn
from torch.fx import GraphModule
from torch.fx.graph import Graph

+from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
@ -13,8 +17,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec


@ -126,6 +128,7 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc


def transform_to_sharded_model(gm: ColoGraphModule,
+                               meta_args: Dict,
                               solution: List[int],
                               device_mesh: DeviceMesh,
                               strategies_constructor: StrategiesConstructor,
@ -142,6 +145,7 @@ def transform_to_sharded_model(gm: ColoGraphModule,
                                  strategies_constructor,
                                  overlap=overlap)
    gm = runtime_apply_pass(gm)
+    shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict)
    gm.recompile()
    sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)

@ -243,10 +247,13 @@ def initialize_model(model: nn.Module,
        solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
        return a series of integers, but return the best strategies.
    '''
-    tracer = ColoTracer(trace_act_ckpt=True)
+    tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)

    graph = tracer.trace(root=model, meta_args=meta_args)
+    graph.set_codegen(ActivationCheckpointCodeGen())
    gm = ColoGraphModule(model, graph, model.__class__.__name__)

+    shape_prop_pass(gm, *meta_args.values())
    gm.recompile()

    strategies_constructor = build_strategy_constructor(graph,
@ -261,7 +268,9 @@ def initialize_model(model: nn.Module,
    if save_solver_solution:
        torch.save(solution, solution_path)

-    gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor, overlap)
+    gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor,
+                                                         overlap)

    model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)

    if return_solution:
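
Taken together, `initialize_model` now builds the graph through the new analyzer stack: trace with `bias_addition_split`, install the activation-checkpoint codegen, run `shape_prop_pass` over the meta arguments, then recompile. A condensed sketch of the same pipeline outside `initialize_model` (every call is taken from the diff; the toy model and argument name are illustrative assumptions):

    import torch
    import torch.nn as nn

    from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
    from colossalai._analyzer.fx.graph_module import ColoGraphModule
    from colossalai._analyzer.fx.passes import shape_prop_pass
    from colossalai._analyzer.fx.tracer.tracer import ColoTracer

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())    # toy stand-in
    meta_args = {'input': torch.rand(4, 8).to('meta')}

    tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)
    graph = tracer.trace(root=model, meta_args=meta_args)
    graph.set_codegen(ActivationCheckpointCodeGen())      # new codegen hook
    gm = ColoGraphModule(model, graph, model.__class__.__name__)
    shape_prop_pass(gm, *meta_args.values())              # propagate meta tensors
    gm.recompile()
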
@ -2,8 +2,6 @@ from typing import Dict, List

import torch

-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
-
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import MetaInfoModuleHandler, ModuleHandler
from .registry import operator_registry
@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
import torch
from torch.fx.node import Node

-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
+from colossalai.auto_parallel.meta_profiler.shard_metainfo import ShardMetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    OperationData,
@ -258,7 +258,7 @@ class MetaInfoNodeHandler(NodeHandler):
    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
        """
        This method is inherited from NodeHandler. It will register the strategies first,
-        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
+        and rewrite the memory_cost and compute_cost of the strategy using the ShardMetaInfo class.
        """
        super().register_strategy(compute_resharding_cost=compute_resharding_cost)
        target = self.get_target_function()
@ -266,15 +266,15 @@ class MetaInfoNodeHandler(NodeHandler):
        # is not patched, we will use the default cost model to compute the cost.
        # TODO: patch all torch functions and modules to make it clean
        if meta_register.has(target.__class__) or meta_register.has(target):
-            metainfo_vector = []
+            strategies_info = []
            for strategy in self.strategies_vector:
-                metainfo = MetaInfo(strategy, target)
+                metainfo = ShardMetaInfo(strategy, target)
                strategy.compute_cost = metainfo.compute_cost
                strategy.memory_cost = metainfo.memory_cost
-                metainfo_vector.append(metainfo)
+                strategies_info.append(metainfo)

            # attach metainfos to the handler
-            setattr(self, "metainfo_vector", metainfo_vector)
+            setattr(self, "strategies_info", strategies_info)

        else:
            logger = get_dist_logger()
@ -313,7 +313,7 @@ class MetaInfoModuleHandler(ModuleHandler):
    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
        """
        This method is inherited from NodeHandler. It will register the strategies first,
-        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
+        and rewrite the memory_cost and compute_cost of the strategy using the ShardMetaInfo class.
        """
        super().register_strategy(compute_resharding_cost=compute_resharding_cost)
        target = self.get_target_function()
@ -321,15 +321,15 @@ class MetaInfoModuleHandler(ModuleHandler):
        # is not patched, we will use the default cost model to compute the cost.
        # TODO: patch all torch functions and modules to make it clean
        if meta_register.has(target.__class__) or meta_register.has(target):
-            metainfo_vector = []
+            strategies_info = []
            for strategy in self.strategies_vector:
-                metainfo = MetaInfo(strategy, target)
+                metainfo = ShardMetaInfo(strategy, target)
                strategy.compute_cost = metainfo.compute_cost
                strategy.memory_cost = metainfo.memory_cost
-                metainfo_vector.append(metainfo)
+                strategies_info.append(metainfo)

            # attach metainfos to the handler
-            setattr(self, "metainfo_vector", metainfo_vector)
+            setattr(self, "strategies_info", strategies_info)

        else:
            logger = get_dist_logger()
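
The handler-side change is mechanical but worth restating: for every candidate strategy, a `ShardMetaInfo` is built from the strategy and the target callable, its `compute_cost` and `memory_cost` overwrite the handler's defaults, and the collected list is attached as `strategies_info`. The loop in isolation (names exactly as in the hunk; `strategies_vector` and `target` are assumed to come from a handler):

    from colossalai.auto_parallel.meta_profiler.shard_metainfo import ShardMetaInfo, meta_register

    def refine_costs(strategies_vector, target):
        """Rewrite per-strategy costs with profiler results, mirroring register_strategy()."""
        if not (meta_register.has(target.__class__) or meta_register.has(target)):
            return []    # fall back to the default cost model, as the else-branch does
        strategies_info = []
        for strategy in strategies_vector:
            metainfo = ShardMetaInfo(strategy, target)
            strategy.compute_cost = metainfo.compute_cost
            strategy.memory_cost = metainfo.memory_cost
            strategies_info.append(metainfo)
        return strategies_info
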
@ -137,9 +137,9 @@ class StrategiesConstructor:
                                        shard_option=self.solver_options.shard_option,
                                        solver_perference=self.solver_options.solver_perference)
                handler.register_strategy()
-                # attach metainfo_vector to node
-                if hasattr(handler, 'metainfo_vector'):
-                    setattr(node, 'metainfo_vector', handler.metainfo_vector)
+                # attach strategies_info to node
+                if hasattr(handler, 'strategies_info'):
+                    setattr(node, 'strategies_info', handler.strategies_info)

            # call_function node
            elif node.op == 'call_function':
@ -150,9 +150,9 @@ class StrategiesConstructor:
                                        shard_option=self.solver_options.shard_option,
                                        solver_perference=self.solver_options.solver_perference)
                handler.register_strategy()
-                # attach metainfo_vector to node
-                if hasattr(handler, 'metainfo_vector'):
-                    setattr(node, 'metainfo_vector', handler.metainfo_vector)
+                # attach strategies_info to node
+                if hasattr(handler, 'strategies_info'):
+                    setattr(node, 'strategies_info', handler.strategies_info)

            # call_method node
            elif node.op == 'call_method':
@ -163,9 +163,9 @@ class StrategiesConstructor:
                                        shard_option=self.solver_options.shard_option,
                                        solver_perference=self.solver_options.solver_perference)
                handler.register_strategy()
-                # attach metainfo_vector to node
-                if hasattr(handler, 'metainfo_vector'):
-                    setattr(node, 'metainfo_vector', handler.metainfo_vector)
+                # attach strategies_info to node
+                if hasattr(handler, 'strategies_info'):
+                    setattr(node, 'strategies_info', handler.strategies_info)

            # output node
            elif node.op == 'output':
tests/test_analyzer/__init__.py  (new empty file, 0 lines changed)
@ -1,10 +1,12 @@
+import pytest
import torch
import torch.nn.functional as F

+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec


@ -33,6 +35,7 @@ def recover_narrow(gm, narrow_node):
    return gm


+@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_size_value_converting_pass():
    model = TestModule()
    physical_mesh_id = torch.arange(0, 4)
@ -40,14 +43,14 @@ def test_size_value_converting_pass():
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    meta_args = {'x': torch.rand(4, 8).to('meta')}
    input = torch.rand(4, 8)
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
    graph = tracer.trace(root=model, meta_args=meta_args)

    x_node = list(graph.nodes)[0]
    x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
    setattr(x_node, 'sharding_spec', x_sharding_spec)
    gm = ColoGraphModule(model, graph)
    gm = insert_narrow(gm, x_node)
+    shape_prop_pass(gm, *meta_args.values())
    gm.recompile()
    size = gm(input)
    assert size == torch.Size([2, 8])
@ -4,7 +4,12 @@ import pytest
import torch
import torch.multiprocessing as mp

-from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+try:
+    from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+    NO_CODEGEN = False
+except:
+    NO_CODEGEN = True

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
@ -77,6 +82,7 @@ def check_conv_module(rank, world_size, port):


@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bias_addition_module():
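
The tests below all adopt the same guard: importing `initialize_model` can fail when the new codegen is unavailable (for example on older PyTorch), so the import is wrapped in try/except and the whole test is skipped via `pytest.mark.skipif`. The idiom, extracted for reference (a sketch of the pattern used in the diff, not an additional test):

    import pytest

    try:
        from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
        NO_CODEGEN = False
    except ImportError:    # the diff uses a bare except; ImportError is the usual case
        NO_CODEGEN = True

    @pytest.mark.skipif(NO_CODEGEN, reason='No codegen found')
    def test_something():
        ...
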
@ -8,13 +8,15 @@ import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers.pytorch_utils import Conv1D

-from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+try:
+    from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+    NO_CODEGEN = False
+except:
+    NO_CODEGEN = True

from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
@ -43,6 +45,7 @@ def check_act_ckpt(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE)
+    input = torch.rand(1, 64, HIDDEN_SIZE)
    input_sample = {
        'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'),
    }
@ -54,10 +57,11 @@ def check_act_ckpt(rank, world_size, port):
    gm = initialize_model(model, input_sample, device_mesh)
    code = gm.module.graph.python_code('self').src
    assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code
-    assert "view_3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, view_1, comm_actions_dict, use_reentrant=True)" in code
+    assert "view_3 = torch.utils.checkpoint.checkpoint(self.checkpoint_0, view_1, comm_actions_dict, use_reentrant=False)" in code


@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_mlp_layer():
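
The updated assertion reflects a codegen change: checkpointed regions are now emitted through the stock `torch.utils.checkpoint.checkpoint` with `use_reentrant=False` instead of the colossalai wrapper with `use_reentrant=True`. For reference, a standalone use of that PyTorch API (the layer and input are illustrative):

    import torch
    from torch.utils.checkpoint import checkpoint

    layer = torch.nn.Linear(64, 64)
    x = torch.rand(1, 64, requires_grad=True)
    # Non-reentrant checkpointing recomputes `layer` during backward instead of
    # storing its intermediate activations.
    y = checkpoint(layer, x, use_reentrant=False)
    y.sum().backward()
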
@ -6,7 +6,12 @@ import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

-from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+try:
+    from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+    NO_CODEGEN = False
+except:
+    NO_CODEGEN = True

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
@ -93,6 +98,7 @@ def check_compatibility_with_ddp(rank, world_size, port):


@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_compatibility_with_ddp():
@ -6,7 +6,12 @@ import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

-from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+try:
+    from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+    NO_CODEGEN = False
+except:
+    NO_CODEGEN = True

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
@ -101,6 +106,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):


@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_auto_parallel_with_gemini():
@ -5,8 +5,11 @@ import torch.nn as nn
from torch.fx import GraphModule
from transformers.pytorch_utils import Conv1D

+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes import shape_prop_pass
+# from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
-from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag

@ -83,11 +86,12 @@ def test_repeat_blocks(model_cls):

    model = model_cls(4 * HIDDEN_DIM, HIDDEN_DIM, NUM_REPEAT_BLOCKS)

-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
    input_sample = {'x': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta')}
    graph = tracer.trace(root=model, meta_args=input_sample)

    gm = GraphModule(model, graph, model.__class__.__name__)
+    shape_prop_pass(gm, *input_sample.values())
    gm.recompile()

    node_list = list(graph.nodes)
@ -10,15 +10,23 @@ import torch.multiprocessing as mp
import transformers
from torch.fx import GraphModule

-from colossalai.auto_parallel.tensor_shard.initialize import (
-    ModuleWrapper,
-    build_strategy_constructor,
-    solve_solution,
-    transform_to_sharded_model,
-)
+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+# from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
+
+try:
+    from colossalai.auto_parallel.tensor_shard.initialize import (
+        ModuleWrapper,
+        build_strategy_constructor,
+        solve_solution,
+        transform_to_sharded_model,
+    )
+    NO_CODEGEN = False
+except:
+    NO_CODEGEN = True

from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import to_global
@ -52,9 +60,8 @@ def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, tor
        param_sharding_spec = best_sharding_spec_dict[new_name]
        grad_to_compare = copy.deepcopy(param_grad)
        param_grad_global = to_global(grad_to_compare, param_sharding_spec)

        try:
-            assert_close_loose(param_grad_global, origin_param_grad, rtol=1e-03, atol=1e-03)
+            assert_close_loose(param_grad_global, origin_param_grad, rtol=1e-03, atol=1e-05)
        except:
            difference = param_grad_global - origin_param_grad
            avg_diff = difference.abs().sum() / difference.numel()
@ -66,7 +73,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    config = transformers.GPT2Config(n_position=64, n_layer=1, n_head=16, n_embd=HIDDEN_DIM)
+    config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM)

    if model_cls == GPT2MLP:
        model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
@ -111,15 +118,17 @@ def check_attention_layer(rank, model_cls, world_size, port):
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)

    graph = tracer.trace(root=model, meta_args=meta_input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
+    shape_prop_pass(gm, *meta_input_sample.values())
    gm.recompile()

    strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
    solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
-    gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
+    gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh,
+                                                         strategies_constructor)
    gm = ModuleWrapper(gm, *sharding_spec_dicts)

    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
@ -176,6 +185,7 @@ def check_attention_layer(rank, model_cls, world_size, port):


@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.skipif(NO_CODEGEN, reason="no codegen module")
@pytest.mark.dist
@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model])
@rerun_if_address_is_in_use()
@ -3,11 +3,12 @@ import torch.nn as nn
import transformers
from torch.fx import GraphModule

+from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@ -21,7 +22,7 @@ HIDDEN_DIM = 384
@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
def test_self_attention_block(model_cls):
-    config = transformers.GPT2Config(n_position=64, n_layer=12, n_head=16, n_embd=HIDDEN_DIM)
+    config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM)
    if model_cls == GPT2MLP:
        model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
    else:
@ -33,7 +34,7 @@ def test_self_attention_block(model_cls):
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    shape_consistency_manager = ShapeConsistencyManager()

-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
    if model_cls == GPT2MLP:
        input_sample = {
            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
@ -52,6 +53,7 @@ def test_self_attention_block(model_cls):
    graph = tracer.trace(root=model, meta_args=input_sample)

    gm = GraphModule(model, graph, model.__class__.__name__)
+    shape_prop_pass(gm, *input_sample.values())
    print(gm.graph)
    gm.recompile()
    solver_options = SolverOptions()
@ -1,8 +1,11 @@
+import pytest
import torch
import torch.nn as nn

+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
-from colossalai.fx import ColoGraphModule, ColoTracer


class LinearModel(nn.Module):
@ -22,15 +25,14 @@ class LinearModel(nn.Module):
        return out


+@pytest.mark.skip('meta tensor has some bugs in 1.11')
def test_liveness_analysis():
    model = LinearModel()
-    tracer = ColoTracer()
-    graph = tracer.trace(model,
-                         meta_args={
-                             'x1': torch.rand(4, 4, device='meta'),
-                             'x2': torch.rand(4, 4, device='meta')
-                         })
+    tracer = ColoTracer(bias_addition_split=True)
+    meta_args = {'x1': torch.rand(4, 4, device='meta'), 'x2': torch.rand(4, 4, device='meta')}
+    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__)
+    shape_prop_pass(gm, *meta_args.values())

    graph_analyser = GraphAnalyser(gm)
    liveness_list = graph_analyser.liveness_analysis()
@ -24,7 +24,7 @@ from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results

if torch.__version__ >= '1.12.0':
-    from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
+    from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@ -17,7 +17,7 @@ from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy

if torch.__version__ >= '1.12.0':
-    from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
+    from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register


class MyModule(nn.Module):
@ -24,7 +24,7 @@ from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results

if torch.__version__ >= '1.12.0':
-    from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
+    from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@ -23,7 +23,7 @@ from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results

if torch.__version__ >= '1.12.0':
-    from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
+    from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register


def _batchnorm_module_mem_test(rank, world_size, port):
@ -24,7 +24,7 @@ from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results

if torch.__version__ >= '1.12.0':
-    from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
+    from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register


class SplitModule(nn.Module):
@ -22,7 +22,7 @@ from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results

if torch.__version__ >= '1.12.0':
-    from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
+    from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@ -5,16 +5,19 @@ from typing import Dict, List
import torch
from torch.fx import GraphModule

+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes import shape_prop_pass
+# from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer

if torch.__version__ >= '1.12.0':
-    from colossalai.auto_parallel.meta_profiler import MetaInfo
+    from colossalai.auto_parallel.meta_profiler import ShardMetaInfo


def mem_test_for_node_strategy(rank: int,
@ -30,14 +33,16 @@ def mem_test_for_node_strategy(rank: int,
    model_to_shard, args_to_shard, kwargs_to_shard = copy.deepcopy(model), copy.deepcopy(input_args), copy.deepcopy(
        input_kwargs)

-    tracer = ColoTracer()
+    tracer = ColoTracer(bias_addition_split=True)
    input_sample = {}
    for input_arg, meta_arg_name in zip(input_args, meta_arg_names):
        input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta')
    for meta_kwarg_name, input_kwarg in input_kwargs.items():
        input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
    graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
-    gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
+    gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
+    shape_prop_pass(gm, *input_sample.values())
+    gm.recompile()
    solver_options = SolverOptions()
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
@ -108,10 +113,10 @@ def mem_test_for_node_strategy(rank: int,

    # estimated memory
    if target_node.op == "call_module":
-        metainfo = MetaInfo(target_node.strategies_vector[strategy_index],
+        metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index],
                             target_node.graph.owning_module.get_submodule(target_node.target))
    else:
-        metainfo = MetaInfo(target_node.strategies_vector[strategy_index], target_node.target)
+        metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index], target_node.target)

    print("estimated memory:")
    print(
|
@ -1,126 +0,0 @@
-import torch
-
-from colossalai.auto_parallel.tensor_shard.options import SolverOptions
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
-from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-
-
-def _param_resharding_cost_assertion(node):
-    for strategy in node.strategies_vector:
-        for prev_node, resharding_cost in strategy.resharding_costs.items():
-            if strategy.get_op_data_by_name(str(prev_node)).type == OperationDataType.PARAM:
-                for cost in resharding_cost:
-                    assert cost.fwd == 0
-                    assert cost.bwd == 0
-                    assert cost.total == 0
-
-
-class LinearModel(torch.nn.Module):
-
-    def __init__(self, in_features, out_features):
-        super().__init__()
-        self.linear = torch.nn.Linear(in_features, out_features)
-
-    def forward(self, x):
-        x = self.linear(x)
-        x = x * 2
-
-        return x
-
-
-class ConvModel(torch.nn.Module):
-
-    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
-        super().__init__()
-        self.conv = torch.nn.Conv2d(in_channels=in_channels,
-                                    out_channels=out_channels,
-                                    kernel_size=kernel_size,
-                                    bias=bias)
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = x * 2
-
-        return x
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_linear_module():
-    model = LinearModel(4, 8)
-    physical_mesh_id = torch.arange(0, 4)
-    mesh_shape = (2, 2)
-    # [[0, 1]
-    # [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    tracer = ColoTracer()
-    # graph():
-    # %x : torch.Tensor [#users=1] = placeholder[target=x]
-    # %linear_weight : [#users=1] = get_attr[target=linear.weight]
-    # %linear_bias : [#users=1] = get_attr[target=linear.bias]
-    # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
-    # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
-    # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
-    # return mul
-    graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')})
-    # def forward(self, x : torch.Tensor):
-    #     linear_weight = self.linear.weight
-    #     linear_bias = self.linear.bias
-    #     linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
-    #     add = linear + linear_bias; linear = linear_bias = None
-    #     mul = add * 2; add = None
-    #     return mul
-    gm = ColoGraphModule(model, graph)
-    gm.recompile()
-    node_list = list(graph.nodes)
-
-    solver_options = SolverOptions()
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-    strategies_constructor.build_strategies_and_cost()
-    linear_node = node_list[3]
-    _param_resharding_cost_assertion(linear_node)
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_conv_module():
-    model = ConvModel(3, 6, 2)
-    physical_mesh_id = torch.arange(0, 4)
-    mesh_shape = (2, 2)
-    # [[0, 1]
-    # [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    tracer = ColoTracer()
-    # graph():
-    # %x : torch.Tensor [#users=1] = placeholder[target=x]
-    # %conv_weight : [#users=1] = get_attr[target=conv.weight]
-    # %conv_bias : [#users=1] = get_attr[target=conv.bias]
-    # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
-    # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
-    # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
-    # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
-    # return mul
-    graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')})
-    # def forward(self, x : torch.Tensor):
-    #     conv_weight = self.conv.weight
-    #     conv_bias = self.conv.bias
-    #     conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None
-    #     view = conv_bias.view([1, -1, 1, 1]); conv_bias = None
-    #     add = conv2d + view; conv2d = view = None
-    #     mul = add * 2; add = None
-    #     return mul
-    gm = ColoGraphModule(model, graph)
-
-    gm.recompile()
-    node_list = list(graph.nodes)
-    conv_node = node_list[3]
-    solver_options = SolverOptions()
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-    strategies_constructor.build_strategies_and_cost()
-    _param_resharding_cost_assertion(conv_node)
-
-
-if __name__ == '__main__':
-    test_linear_module()
-    test_conv_module()
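This test file is deleted outright rather than migrated. If its parameter-resharding-cost assertions were ever rebuilt on top of the new analyzer, the tracing prologue would presumably follow the pattern introduced elsewhere in this commit; the fragment below is a speculative sketch of such a port (the helper name is hypothetical), not part of the diff.

from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.fx import ColoGraphModule


def trace_for_strategy_tests(model, meta_args):
    # hypothetical helper: trace on meta tensors with the new tracer, then
    # materialize shape metadata before building sharding strategies
    tracer = ColoTracer(bias_addition_split=True)
    graph = tracer.trace(root=model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph, model.__class__.__name__)
    shape_prop_pass(gm, *meta_args.values())
    gm.recompile()
    return graph, gm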
@ -1,86 +0,0 @@
-import copy
-from functools import partial
-
-import pytest
-import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-
-from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.initialize import launch
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
-
-
-class ConvModel(nn.Module):
-
-    def __init__(self, c_in, c_out):
-        super().__init__()
-        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = torch.flatten(x)
-        return x
-
-
-def check_apply(rank, world_size, port):
-    disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    input = torch.rand(4, 4, 4, 4).cuda()
-    test_input = copy.deepcopy(input)
-    # graph():
-    # %x : torch.Tensor [#users=1] = placeholder[target=x]
-    # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
-    # return conv
-    model = ConvModel(4, 4).cuda()
-    test_model = copy.deepcopy(model)
-    physical_mesh_id = torch.arange(0, 4)
-    mesh_shape = (2, 2)
-    # [[0, 1]
-    # [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    meta_args = {'x': torch.rand(4, 4, 4, 4).to('meta')}
-    gm = initialize_model(model, meta_args, device_mesh)
-
-    output = gm(input)
-    origin_output = test_model(test_input)
-    assert output.equal(origin_output)
-    origin_loss = origin_output.sum()
-    loss = output.sum()
-
-    origin_loss.backward()
-    loss.backward()
-
-    grad_0 = test_model.conv.weight.grad.narrow(0, 0, 1)
-    grad_1 = test_model.conv.weight.grad.narrow(0, 1, 1)
-    grad_2 = test_model.conv.weight.grad.narrow(0, 2, 1)
-    grad_3 = test_model.conv.weight.grad.narrow(0, 3, 1)
-
-    if rank == 0:
-        assert_close(gm.module.conv.weight.grad.data, grad_0.data)
-    elif rank == 1:
-        assert_close(gm.module.conv.weight.grad.data, grad_1.data)
-    elif rank == 2:
-        assert_close(gm.module.conv.weight.grad.data, grad_2.data)
-    elif rank == 3:
-        assert_close(gm.module.conv.weight.grad.data, grad_3.data)
-    else:
-        raise ValueError(f'rank {rank} does not exist.')
-
-
-# skip this test due to pulp not installed in CI environment
-@run_on_environment_flag(name='AUTO_PARALLEL')
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_apply():
-    world_size = 4
-    run_func = partial(check_apply, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_apply()
@ -2,11 +2,13 @@ import torch
from torch.fx import GraphModule
from torchvision.models import resnet50

+from colossalai._analyzer.fx.passes import shape_prop_pass
+# from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing.pytest_wrapper import run_on_environment_flag

@ -20,7 +22,7 @@ def test_cost_graph():
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    shape_consistency_manager = ShapeConsistencyManager()

-   tracer = ColoTracer()
+   tracer = ColoTracer(bias_addition_split=True)
    model = resnet50(num_classes=100000)
    input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')}

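Two related changes meet here: the ColoTracer import now comes from colossalai._analyzer.fx.tracer.tracer instead of colossalai.fx.tracer.tracer (see the import hunk above), and the tracer is constructed with bias_addition_split=True. Judging from the graphs quoted in the removed tests earlier in this diff, the flag appears to make the tracer emit bias addition as a separate operator.add node instead of folding it into the linear/conv call; a short hedged illustration, with the example graph lines copied from those tests:

from colossalai._analyzer.fx.tracer.tracer import ColoTracer   # new import path

# Expected shape of the traced graph around a linear layer, e.g.:
#   %linear : call_function[target=torch._C._nn.linear](args = (%x, %linear_weight))
#   %add    : call_function[target=operator.add](args = (%linear, %linear_bias))
tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(root=model, meta_args=input_sample)   # model / input_sample as in the test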
@ -50,6 +52,7 @@ def test_cost_graph():
    # %fc : [#users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
    # return fc
    gm = GraphModule(model, graph, model.__class__.__name__)
+   shape_prop_pass(gm, *input_sample.values())
    gm.recompile()

    solver_options = SolverOptions()