mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[autoparallel] fix bias addition module (#1800)
This commit is contained in:
@@ -43,7 +43,7 @@ class BiasAdditionConv(BiasAdditionModule):
|
||||
bias_shape[0] = -1
|
||||
bias_reshape_node_kind = 'call_method'
|
||||
bias_reshape_node_target = 'view'
|
||||
bias_reshape_node_args = (self.bias_proxy, bias_shape)
|
||||
bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
|
||||
bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
|
||||
bias_reshape_node_args, {})
|
||||
return bias_reshape_proxy
|
||||
|
@@ -58,7 +58,7 @@ def torch_bmm(input, mat2, *, out=None):
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.nn.functional.linear)
|
||||
def torch_linear(input, mat2, *, out=None):
|
||||
def torch_linear(input, mat2, bias=None, *, out=None):
|
||||
if out is not None:
|
||||
raise ValueError("Don't support in-place abs for MetaTensor analysis")
|
||||
output_shape = list(input.shape)
|
||||
|
Reference in New Issue
Block a user