mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[autoparallel] record parameter attribute in colotracer (#2217)
* [autoparallel] record parameter attribute in colotracer * [autoparallel] fix construct_meta_info bug
This commit is contained in:
@@ -174,8 +174,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
|
||||
runtime_apply,
|
||||
args=(node, origin_dict_node, input_dict_node,
|
||||
node_to_index_dict[node], user_node_index))
|
||||
meta_info = construct_meta_info(node, user_node)
|
||||
setattr(shape_consistency_node, 'best_metainfo', meta_info)
|
||||
# meta_info = construct_meta_info(node, user_node)
|
||||
# setattr(shape_consistency_node, 'best_metainfo', meta_info)
|
||||
|
||||
new_args = list(user_node.args)
|
||||
new_kwargs = dict(user_node.kwargs)
|
||||
|
||||
@@ -229,6 +229,15 @@ class ColoTracer(Tracer):
|
||||
args_metas, kwargs_metas = extract_meta(*args, **kwargs)
|
||||
|
||||
if kind == "call_function":
|
||||
# Our meta data will not record the nn.parameter.Parameter attribute.
|
||||
# It works fine in most cases, but it may cause some problems after
|
||||
# the bias addition manipulation.
|
||||
# Therefore, I need to record the nn.parameter.Parameter attribute for the operation
|
||||
# added by the bias addition manipulation following the get_attr node.
|
||||
convert_to_parameter = False
|
||||
if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0],
|
||||
torch.nn.parameter.Parameter):
|
||||
convert_to_parameter = True
|
||||
# fetch patched function
|
||||
if meta_patched_function.has(target):
|
||||
meta_target = meta_patched_function.get(target)
|
||||
@@ -241,7 +250,18 @@ class ColoTracer(Tracer):
|
||||
meta_out = meta_target(*args_metas, **kwargs_metas)
|
||||
if isinstance(meta_out, torch.Tensor):
|
||||
meta_out = meta_out.to(device="meta")
|
||||
if convert_to_parameter:
|
||||
meta_out = torch.nn.Parameter(meta_out)
|
||||
|
||||
elif kind == "call_method":
|
||||
# Our meta data will not record the nn.parameter.Parameter attribute.
|
||||
# It works fine in most cases, but it may cause some problems after
|
||||
# the bias addition manipulation.
|
||||
# Therefore, I need to record the nn.parameter.Parameter attribute for the operation
|
||||
# added by the bias addition manipulation following the get_attr node.
|
||||
convert_to_parameter = False
|
||||
if target in (torch.Tensor.view,) and isinstance(args_metas[0], torch.nn.parameter.Parameter):
|
||||
convert_to_parameter = True
|
||||
method = getattr(args_metas[0].__class__, target)
|
||||
|
||||
# fetch patched method
|
||||
@@ -251,6 +271,8 @@ class ColoTracer(Tracer):
|
||||
meta_target = method
|
||||
|
||||
meta_out = meta_target(*args_metas, **kwargs_metas)
|
||||
if convert_to_parameter:
|
||||
meta_out = torch.nn.Parameter(meta_out)
|
||||
elif kind == "call_module":
|
||||
if not hasattr(self, "orig_forward"):
|
||||
raise AttributeError(f"{self} does not have an attribute called orig_forward")
|
||||
|
||||
Reference in New Issue
Block a user