[autoparallel]integrate auto parallel feature with new tracer (#3408)

* [autoparallel] integrate new analyzer in module level

* unify the profiling method

* polish

* fix no codegen bug

* fix pass bug

* fix liveness test

* polish
This commit is contained in:
YuliangLiu0306
2023-04-04 17:40:45 +08:00
committed by GitHub
parent 573af84184
commit ffcdbf0f65
46 changed files with 396 additions and 470 deletions

View File

@@ -235,7 +235,28 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
# Inputs contains the shapes of two matrices.
input_shapes = [v.shape for v in inputs]
assert len(input_shapes) == 2, input_shapes
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
# There are three cases: 1) gemm, 2) gemv, 3) dot
if all(len(shape) == 2 for shape in input_shapes):
# gemm
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
elif all(len(shape) == 1 for shape in input_shapes):
# dot
assert input_shapes[0][0] == input_shapes[1][0], input_shapes
# expand shape
input_shapes[0] = torch.Size([1, input_shapes[0][0]])
input_shapes[1] = torch.Size([input_shapes[1][0], 1])
else:
# gemv
if len(input_shapes[0]) == 1:
assert input_shapes[0][0] == input_shapes[1][-2], input_shapes
input_shapes.reverse()
else:
assert input_shapes[1][0] == input_shapes[0][-1], input_shapes
# expand the shape of the vector to [batch size, 1]
input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1])
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
return flops

View File

@@ -1,8 +1,12 @@
from typing import Any, Callable, Dict, Iterable, List, Tuple
import torch
try:
from torch.fx.graph import CodeGen
except:
pass
from torch.fx.graph import (
CodeGen,
PythonCode,
_custom_builtins,
_format_target,
@@ -48,8 +52,8 @@ def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
"""
Check if the node could end the ckpt region at `ckpt_level`
"""
if len(node.meta['info'].to_recompute) > ckpt_level:
return node.meta['info'].to_recompute[ckpt_level] is not None
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
return node.meta['info'].activation_checkpoint[ckpt_level] is not None
return True
@@ -90,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
current_region = None
for idx, node in enumerate(node_list):
if len(node.meta['info'].to_recompute) > ckpt_level:
act_ckpt_label = node.meta['info'].to_recompute[ckpt_level]
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@@ -152,12 +156,12 @@ def emit_ckpt_func(body,
# label given by each layer, e.g. if you are currently at level (0, 1, 1)
# the label will be '0_1_1'
label = "_".join([str(idx) for idx in node_list[0].meta['info'].to_recompute[:ckpt_level + 1]])
label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')
# if there is more level to fetch
if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].to_recompute), node_list)):
if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
@@ -215,7 +219,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
node_list = list(nodes)
node_idx = 0

View File

@@ -112,7 +112,7 @@ class MetaInfo:
# should keep the same whenever manipulated
# ============================= Invariant ==================================
to_recompute: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
to_offload: Optional[bool] = False
sharding_spec: str = 'RR'

View File

@@ -237,7 +237,14 @@ class ShapeProp(torch.fx.Interpreter):
Returns:
Any: The value returned from executing the Module
"""
wrap_fn = lambda elem: MetaTensor(elem, device=device)
# wrap_fn = lambda elem: MetaTensor(elem, device=device)
def wrap_fn(elem, device=device):
if isinstance(elem, torch.Tensor):
return MetaTensor(elem, device=device)
else:
return elem
with self._mode:
return super().run(*tree_map(wrap_fn, args))

View File

@@ -21,69 +21,111 @@ def linear_impl(input, weight, bias=None):
@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
def conv1d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
if bias is None:
return F.conv1d(input, weight, **kwargs)
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv1d(input, weight, **kwargs) + bias.reshape((-1, 1))
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1))
@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
def conv2d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
if bias is None:
return F.conv2d(input, weight, **kwargs)
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv2d(input, weight, **kwargs) + bias.reshape((-1, 1, 1))
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1))
@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
def conv3d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
if bias is None:
return F.conv3d(input, weight, **kwargs)
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1, 1))
@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
def conv_transpose1d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv_transpose1d_impl(input,
weight,
bias=None,
stride=_single(1),
padding=_single(0),
output_padding=_single(0),
groups=1,
dilation=_single(1)):
if bias is None:
return F.conv_transpose1d(input, weight, **kwargs)
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1))
@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
def conv_transpose2d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv_transpose2d_impl(input,
weight,
bias=None,
stride=_pair(1),
padding=_pair(0),
output_padding=_pair(0),
groups=1,
dilation=_pair(1)):
if bias is None:
return F.conv_transpose2d(input, weight, **kwargs)
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1))
@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
def conv_transpose3d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv_transpose3d_impl(input,
weight,
bias=None,
stride=_triple(1),
padding=_triple(0),
output_padding=_triple(0),
groups=1,
dilation=_triple(1)):
if bias is None:
return F.conv_transpose3d(input, weight, **kwargs)
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1, 1))
@register_tracer_impl(torch.addmm, name='_bias_addition_impl')

View File

@@ -155,7 +155,7 @@ class ColoTracer(Tracer):
def create_node(self, *args, **kwargs) -> Node:
node = super().create_node(*args, **kwargs)
n_info = MetaInfo(node, mod_dir=self.mod_dir, to_recompute=tuple(self.ckpt_regions))
n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
return node
def trace(self,