mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -21,16 +21,15 @@ def _normalize_tuple(x):
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
class MetaInfoProp:
|
||||
|
||||
def __init__(self, module: GraphModule) -> None:
|
||||
self.module = module
|
||||
self.func_dict = {
|
||||
'placeholder': self.placeholder_handler,
|
||||
'get_attr': self.get_attr_handler,
|
||||
'output': self.output_handler,
|
||||
'call_function': self.node_handler,
|
||||
'call_module': self.node_handler,
|
||||
'call_method': self.node_handler,
|
||||
"placeholder": self.placeholder_handler,
|
||||
"get_attr": self.get_attr_handler,
|
||||
"output": self.output_handler,
|
||||
"call_function": self.node_handler,
|
||||
"call_module": self.node_handler,
|
||||
"call_method": self.node_handler,
|
||||
}
|
||||
|
||||
def _set_data_ptr(self, x):
|
||||
@@ -46,7 +45,7 @@ class MetaInfoProp:
|
||||
"""
|
||||
Check if the node is inplace operation.
|
||||
"""
|
||||
if node.op == 'call_module':
|
||||
if node.op == "call_module":
|
||||
return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
|
||||
elif node.op == "call_function":
|
||||
return node.target in OUTPUT_SAVED_OPS
|
||||
@@ -66,7 +65,7 @@ class MetaInfoProp:
|
||||
Handle the placeholder node.
|
||||
"""
|
||||
graph_info = GraphInfo()
|
||||
out = _normalize_tuple(getattr(node, '_meta_data', None))
|
||||
out = _normalize_tuple(getattr(node, "_meta_data", None))
|
||||
graph_info.fwd_out = list(out) if out[0] is not None else []
|
||||
node.meta = {**asdict(graph_info)}
|
||||
|
||||
@@ -96,7 +95,7 @@ class MetaInfoProp:
|
||||
"""
|
||||
Handle other kind of nodes
|
||||
"""
|
||||
assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
|
||||
assert hasattr(node, "best_strategy_info"), f"Cannot find best_strategy_info in node {node}, {node.op}"
|
||||
graph_info = GraphInfo()
|
||||
meta_info = node.best_strategy_info
|
||||
meta_info: ShardMetaInfo
|
||||
@@ -126,7 +125,8 @@ class MetaInfoProp:
|
||||
for tensor in par.meta.get("fwd_out", []):
|
||||
tensor: torch.Tensor
|
||||
target_input_tensor = next(
|
||||
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
|
||||
(x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None
|
||||
)
|
||||
if target_input_tensor is not None:
|
||||
target_input_tensor.data_ptr = tensor.data_ptr
|
||||
|
||||
|
Reference in New Issue
Block a user