mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[autoparallel]integrate auto parallel feature with new tracer (#3408)
* [autoparallel] integrate new analyzer in module level * unify the profiling method * polish * fix no codegen bug * fix pass bug * fix liveness test * polish
This commit is contained in:
@@ -21,69 +21,111 @@ def linear_impl(input, weight, bias=None):
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
|
||||
def conv1d_impl(input, weight, **kwargs):
|
||||
bias = getattr(kwargs, 'bias', None)
|
||||
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
|
||||
if bias is None:
|
||||
return F.conv1d(input, weight, **kwargs)
|
||||
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
else:
|
||||
new_kwargs = kwargs
|
||||
new_kwargs['bias'] = None
|
||||
return F.conv1d(input, weight, **kwargs) + bias.reshape((-1, 1))
|
||||
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
|
||||
(-1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
|
||||
def conv2d_impl(input, weight, **kwargs):
|
||||
bias = getattr(kwargs, 'bias', None)
|
||||
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
|
||||
if bias is None:
|
||||
return F.conv2d(input, weight, **kwargs)
|
||||
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
else:
|
||||
new_kwargs = kwargs
|
||||
new_kwargs['bias'] = None
|
||||
return F.conv2d(input, weight, **kwargs) + bias.reshape((-1, 1, 1))
|
||||
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
|
||||
(-1, 1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
|
||||
def conv3d_impl(input, weight, **kwargs):
|
||||
bias = getattr(kwargs, 'bias', None)
|
||||
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
|
||||
if bias is None:
|
||||
return F.conv3d(input, weight, **kwargs)
|
||||
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
else:
|
||||
new_kwargs = kwargs
|
||||
new_kwargs['bias'] = None
|
||||
return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
|
||||
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
|
||||
(-1, 1, 1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
|
||||
def conv_transpose1d_impl(input, weight, **kwargs):
|
||||
bias = getattr(kwargs, 'bias', None)
|
||||
def conv_transpose1d_impl(input,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=_single(1),
|
||||
padding=_single(0),
|
||||
output_padding=_single(0),
|
||||
groups=1,
|
||||
dilation=_single(1)):
|
||||
if bias is None:
|
||||
return F.conv_transpose1d(input, weight, **kwargs)
|
||||
return F.conv_transpose1d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation)
|
||||
else:
|
||||
new_kwargs = kwargs
|
||||
new_kwargs['bias'] = None
|
||||
return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))
|
||||
return F.conv_transpose1d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation) + bias.reshape((-1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
|
||||
def conv_transpose2d_impl(input, weight, **kwargs):
|
||||
bias = getattr(kwargs, 'bias', None)
|
||||
def conv_transpose2d_impl(input,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=_pair(1),
|
||||
padding=_pair(0),
|
||||
output_padding=_pair(0),
|
||||
groups=1,
|
||||
dilation=_pair(1)):
|
||||
if bias is None:
|
||||
return F.conv_transpose2d(input, weight, **kwargs)
|
||||
return F.conv_transpose2d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation)
|
||||
else:
|
||||
new_kwargs = kwargs
|
||||
new_kwargs['bias'] = None
|
||||
return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))
|
||||
return F.conv_transpose2d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation) + bias.reshape((-1, 1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
|
||||
def conv_transpose3d_impl(input, weight, **kwargs):
|
||||
bias = getattr(kwargs, 'bias', None)
|
||||
def conv_transpose3d_impl(input,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=_triple(1),
|
||||
padding=_triple(0),
|
||||
output_padding=_triple(0),
|
||||
groups=1,
|
||||
dilation=_triple(1)):
|
||||
if bias is None:
|
||||
return F.conv_transpose3d(input, weight, **kwargs)
|
||||
return F.conv_transpose3d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation)
|
||||
else:
|
||||
new_kwargs = kwargs
|
||||
new_kwargs['bias'] = None
|
||||
return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
|
||||
return F.conv_transpose3d(input,
|
||||
weight,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation) + bias.reshape((-1, 1, 1, 1))
|
||||
|
||||
|
||||
@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
|
||||
|
@@ -155,7 +155,7 @@ class ColoTracer(Tracer):
|
||||
|
||||
def create_node(self, *args, **kwargs) -> Node:
|
||||
node = super().create_node(*args, **kwargs)
|
||||
n_info = MetaInfo(node, mod_dir=self.mod_dir, to_recompute=tuple(self.ckpt_regions))
|
||||
n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
|
||||
return node
|
||||
|
||||
def trace(self,
|
||||
|
Reference in New Issue
Block a user