mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[autoparallel]integrate auto parallel feature with new tracer (#3408)
* [autoparallel] integrate new analyzer in module level * unify the profiling method * polish * fix no codegen bug * fix pass bug * fix liveness test * polish
This commit is contained in:
@@ -235,7 +235,28 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
|
||||
# Inputs contains the shapes of two matrices.
|
||||
input_shapes = [v.shape for v in inputs]
|
||||
assert len(input_shapes) == 2, input_shapes
|
||||
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
|
||||
|
||||
# There are three cases: 1) gemm, 2) gemv, 3) dot
|
||||
if all(len(shape) == 2 for shape in input_shapes):
|
||||
# gemm
|
||||
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
|
||||
elif all(len(shape) == 1 for shape in input_shapes):
|
||||
# dot
|
||||
assert input_shapes[0][0] == input_shapes[1][0], input_shapes
|
||||
|
||||
# expand shape
|
||||
input_shapes[0] = torch.Size([1, input_shapes[0][0]])
|
||||
input_shapes[1] = torch.Size([input_shapes[1][0], 1])
|
||||
else:
|
||||
# gemv
|
||||
if len(input_shapes[0]) == 1:
|
||||
assert input_shapes[0][0] == input_shapes[1][-2], input_shapes
|
||||
input_shapes.reverse()
|
||||
else:
|
||||
assert input_shapes[1][0] == input_shapes[0][-1], input_shapes
|
||||
|
||||
# expand the shape of the vector to [batch size, 1]
|
||||
input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1])
|
||||
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
|
||||
return flops
|
||||
|
||||
|
@@ -1,8 +1,12 @@
|
||||
from typing import Any, Callable, Dict, Iterable, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from torch.fx.graph import CodeGen
|
||||
except:
|
||||
pass
|
||||
from torch.fx.graph import (
|
||||
CodeGen,
|
||||
PythonCode,
|
||||
_custom_builtins,
|
||||
_format_target,
|
||||
@@ -48,8 +52,8 @@ def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
|
||||
"""
|
||||
Check if the node could end the ckpt region at `ckpt_level`
|
||||
"""
|
||||
if len(node.meta['info'].to_recompute) > ckpt_level:
|
||||
return node.meta['info'].to_recompute[ckpt_level] is not None
|
||||
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
|
||||
return node.meta['info'].activation_checkpoint[ckpt_level] is not None
|
||||
return True
|
||||
|
||||
|
||||
@@ -90,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
|
||||
current_region = None
|
||||
|
||||
for idx, node in enumerate(node_list):
|
||||
if len(node.meta['info'].to_recompute) > ckpt_level:
|
||||
act_ckpt_label = node.meta['info'].to_recompute[ckpt_level]
|
||||
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
|
||||
act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
|
||||
|
||||
# this activation checkpoint label is not set yet
|
||||
# meaning this is the first node of the activation ckpt region
|
||||
@@ -152,12 +156,12 @@ def emit_ckpt_func(body,
|
||||
|
||||
# label given by each layer, e.g. if you are currently at level (0, 1, 1)
|
||||
# the label will be '0_1_1'
|
||||
label = "_".join([str(idx) for idx in node_list[0].meta['info'].to_recompute[:ckpt_level + 1]])
|
||||
label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
|
||||
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
|
||||
ckpt_func.append(f'{ckpt_fn_def}\n')
|
||||
|
||||
# if there is more level to fetch
|
||||
if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].to_recompute), node_list)):
|
||||
if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
|
||||
ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
|
||||
start_idx = [item[0] for item in ckpt_regions]
|
||||
end_idx = [item[1] for item in ckpt_regions]
|
||||
@@ -215,7 +219,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
||||
ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
|
||||
start_idx = [item[0] for item in ckpt_regions]
|
||||
end_idx = [item[1] for item in ckpt_regions]
|
||||
|
||||
node_list = list(nodes)
|
||||
|
||||
node_idx = 0
|
||||
|
@@ -112,7 +112,7 @@ class MetaInfo:
|
||||
|
||||
# should keep the same whenever manipulated
|
||||
# ============================= Invariant ==================================
|
||||
to_recompute: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
|
||||
activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
|
||||
to_offload: Optional[bool] = False
|
||||
sharding_spec: str = 'RR'
|
||||
|
||||
|
@@ -237,7 +237,14 @@ class ShapeProp(torch.fx.Interpreter):
|
||||
Returns:
|
||||
Any: The value returned from executing the Module
|
||||
"""
|
||||
wrap_fn = lambda elem: MetaTensor(elem, device=device)
|
||||
|
||||
# wrap_fn = lambda elem: MetaTensor(elem, device=device)
|
||||
def wrap_fn(elem, device=device):
|
||||
if isinstance(elem, torch.Tensor):
|
||||
return MetaTensor(elem, device=device)
|
||||
else:
|
||||
return elem
|
||||
|
||||
with self._mode:
|
||||
return super().run(*tree_map(wrap_fn, args))
|
||||
|
||||
|
@@ -21,69 +21,111 @@ def linear_impl(input, weight, bias=None):
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
|
||||
def conv1d_impl(input, weight, **kwargs):
|
||||
bias = getattr(kwargs, 'bias', None)
|
||||
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
|
||||
if bias is None:
|
||||
return F.conv1d(input, weight, **kwargs)
|
||||
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
else:
|
||||
new_kwargs = kwargs
|
||||
new_kwargs['bias'] = None
|
||||
return F.conv1d(input, weight, **kwargs) + bias.reshape((-1, 1))
|
||||
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
|
||||
(-1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
|
||||
def conv2d_impl(input, weight, **kwargs):
|
||||
bias = getattr(kwargs, 'bias', None)
|
||||
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
|
||||
if bias is None:
|
||||
return F.conv2d(input, weight, **kwargs)
|
||||
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
else:
|
||||
new_kwargs = kwargs
|
||||
new_kwargs['bias'] = None
|
||||
return F.conv2d(input, weight, **kwargs) + bias.reshape((-1, 1, 1))
|
||||
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
|
||||
(-1, 1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
|
||||
def conv3d_impl(input, weight, **kwargs):
|
||||
bias = getattr(kwargs, 'bias', None)
|
||||
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
|
||||
if bias is None:
|
||||
return F.conv3d(input, weight, **kwargs)
|
||||
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
else:
|
||||
new_kwargs = kwargs
|
||||
new_kwargs['bias'] = None
|
||||
return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
|
||||
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
|
||||
(-1, 1, 1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
|
||||
def conv_transpose1d_impl(input, weight, **kwargs):
|
||||
bias = getattr(kwargs, 'bias', None)
|
||||
def conv_transpose1d_impl(input,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=_single(1),
|
||||
padding=_single(0),
|
||||
output_padding=_single(0),
|
||||
groups=1,
|
||||
dilation=_single(1)):
|
||||
if bias is None:
|
||||
return F.conv_transpose1d(input, weight, **kwargs)
|
||||
return F.conv_transpose1d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation)
|
||||
else:
|
||||
new_kwargs = kwargs
|
||||
new_kwargs['bias'] = None
|
||||
return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))
|
||||
return F.conv_transpose1d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation) + bias.reshape((-1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
|
||||
def conv_transpose2d_impl(input, weight, **kwargs):
|
||||
bias = getattr(kwargs, 'bias', None)
|
||||
def conv_transpose2d_impl(input,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=_pair(1),
|
||||
padding=_pair(0),
|
||||
output_padding=_pair(0),
|
||||
groups=1,
|
||||
dilation=_pair(1)):
|
||||
if bias is None:
|
||||
return F.conv_transpose2d(input, weight, **kwargs)
|
||||
return F.conv_transpose2d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation)
|
||||
else:
|
||||
new_kwargs = kwargs
|
||||
new_kwargs['bias'] = None
|
||||
return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))
|
||||
return F.conv_transpose2d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation) + bias.reshape((-1, 1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
|
||||
def conv_transpose3d_impl(input, weight, **kwargs):
|
||||
bias = getattr(kwargs, 'bias', None)
|
||||
def conv_transpose3d_impl(input,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=_triple(1),
|
||||
padding=_triple(0),
|
||||
output_padding=_triple(0),
|
||||
groups=1,
|
||||
dilation=_triple(1)):
|
||||
if bias is None:
|
||||
return F.conv_transpose3d(input, weight, **kwargs)
|
||||
return F.conv_transpose3d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation)
|
||||
else:
|
||||
new_kwargs = kwargs
|
||||
new_kwargs['bias'] = None
|
||||
return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
|
||||
return F.conv_transpose3d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation) + bias.reshape((-1, 1, 1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
|
||||
|
@@ -155,7 +155,7 @@ class ColoTracer(Tracer):
|
||||
|
||||
def create_node(self, *args, **kwargs) -> Node:
|
||||
node = super().create_node(*args, **kwargs)
|
||||
n_info = MetaInfo(node, mod_dir=self.mod_dir, to_recompute=tuple(self.ckpt_regions))
|
||||
n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
|
||||
return node
|
||||
|
||||
def trace(self,
|
||||
|
Reference in New Issue
Block a user