[autoparallel]integrate auto parallel feature with new tracer (#3408)

* [autoparallel] integrate new analyzer in module level

* unify the profiling method

* polish

* fix no codegen bug

* fix pass bug

* fix liveness test

* polish
This commit is contained in:
YuliangLiu0306
2023-04-04 17:40:45 +08:00
committed by GitHub
parent 573af84184
commit ffcdbf0f65
46 changed files with 396 additions and 470 deletions

View File

@@ -21,69 +21,111 @@ def linear_impl(input, weight, bias=None):
@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
def conv1d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
if bias is None:
return F.conv1d(input, weight, **kwargs)
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv1d(input, weight, **kwargs) + bias.reshape((-1, 1))
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1))
@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
def conv2d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
if bias is None:
return F.conv2d(input, weight, **kwargs)
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv2d(input, weight, **kwargs) + bias.reshape((-1, 1, 1))
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1))
@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
def conv3d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
if bias is None:
return F.conv3d(input, weight, **kwargs)
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1, 1))
@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
def conv_transpose1d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv_transpose1d_impl(input,
weight,
bias=None,
stride=_single(1),
padding=_single(0),
output_padding=_single(0),
groups=1,
dilation=_single(1)):
if bias is None:
return F.conv_transpose1d(input, weight, **kwargs)
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1))
@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
def conv_transpose2d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv_transpose2d_impl(input,
weight,
bias=None,
stride=_pair(1),
padding=_pair(0),
output_padding=_pair(0),
groups=1,
dilation=_pair(1)):
if bias is None:
return F.conv_transpose2d(input, weight, **kwargs)
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1))
@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
def conv_transpose3d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
def conv_transpose3d_impl(input,
weight,
bias=None,
stride=_triple(1),
padding=_triple(0),
output_padding=_triple(0),
groups=1,
dilation=_triple(1)):
if bias is None:
return F.conv_transpose3d(input, weight, **kwargs)
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else:
new_kwargs = kwargs
new_kwargs['bias'] = None
return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1, 1))
@register_tracer_impl(torch.addmm, name='_bias_addition_impl')

View File

@@ -155,7 +155,7 @@ class ColoTracer(Tracer):
def create_node(self, *args, **kwargs) -> Node:
node = super().create_node(*args, **kwargs)
n_info = MetaInfo(node, mod_dir=self.mod_dir, to_recompute=tuple(self.ckpt_regions))
n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
return node
def trace(self,