[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
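As a rough illustration of the housekeeping described in the commit message above, a .pre-commit-config.yaml along these lines would run Black over the Python sources and keep clang-format away from CUDA files. The repository URLs and hook ids below are real pre-commit hooks, but the revisions and the exclude pattern are assumptions made for illustration, not values taken from this commit:

repos:
  - repo: https://github.com/psf/black
    rev: 23.9.1                # assumed revision
    hooks:
      - id: black              # produces the quote and line-wrapping changes seen in the diff below
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6               # assumed revision
    hooks:
      - id: clang-format
        exclude: \.cu$         # "[misc] ignore cuda for clang-format": skip CUDA sources

Running `pre-commit run --all-files` then applies every hook to the whole tree at once, which is why a single housekeeping commit touches 1268 files.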


@@ -29,7 +29,7 @@ from .registry import (
     meta_patched_module,
 )
-__all__ = ['ColoTracer']
+__all__ = ["ColoTracer"]
 class TracerType(enum.Enum):
@@ -103,7 +103,7 @@ class ColoTracer(Tracer):
         if kind == "call_function":
             if bias_addition_function.has(target):
                 if target == torch.nn.functional.linear:
-                    if 'bias' in kwargs and kwargs['bias'] is not None:
+                    if "bias" in kwargs and kwargs["bias"] is not None:
                         function_to_substitute = func_to_func_dict[target]
                         handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
                     else:
@@ -160,22 +160,27 @@ class ColoTracer(Tracer):
                     if n not in parameter_proxy_cache:
                         kwargs = {}
                         if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
-                            kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
-                                                          lambda node: ParameterProxy(self, node, n, attr_val))
-                        val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)    # type: ignore[arg-type]
+                            kwargs["proxy_factory_fn"] = (
+                                None
+                                if not self.param_shapes_constant
+                                else lambda node: ParameterProxy(self, node, n, attr_val)
+                            )
+                        val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
                         parameter_proxy_cache[n] = val_proxy
                     return parameter_proxy_cache[n]
             return None

         if isinstance(attr_val, torch.nn.Parameter):
-            maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
-                                                             parameter_proxy_cache)
+            maybe_parameter_proxy = maybe_get_proxy_for_attr(
+                attr_val, self.root.named_parameters(), parameter_proxy_cache
+            )
             if maybe_parameter_proxy is not None:
                 return maybe_parameter_proxy

         if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
-            maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
-                                                          parameter_proxy_cache)
+            maybe_buffer_proxy = maybe_get_proxy_for_attr(
+                attr_val, self.root.named_buffers(), parameter_proxy_cache
+            )
             if maybe_buffer_proxy is not None:
                 return maybe_buffer_proxy
@@ -190,7 +195,7 @@ class ColoTracer(Tracer):
         # if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
         # we should treat it as leaf module as well
         if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
-            return self.create_proxy('call_module', module_qualified_name, args, kwargs)
+            return self.create_proxy("call_module", module_qualified_name, args, kwargs)
         else:
             return forward(*args, **kwargs)
@@ -211,7 +216,6 @@ class ColoTracer(Tracer):
             raise ValueError(f"Unrecognized tracer type {tracer_type}")

     def _meta_data_computing(self, kind, target, args, kwargs):
         if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
             meta_out = self.meta_args[target]
             return meta_out
@@ -235,8 +239,9 @@ class ColoTracer(Tracer):
                 # Therefore, I need to record the nn.parameter.Parameter attribute for the operation
                 # added by the bias addition manipulation following the get_attr node.
                 convert_to_parameter = False
-                if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0],
-                                                                             torch.nn.parameter.Parameter):
+                if target in (torch.transpose, torch.reshape) and isinstance(
+                    args_metas[0], torch.nn.parameter.Parameter
+                ):
                     convert_to_parameter = True
                 # fetch patched function
                 if meta_patched_function.has(target):
@@ -309,10 +314,12 @@ class ColoTracer(Tracer):
         return meta_out

-    def trace(self,
-              root: nn.Module,
-              concrete_args: Optional[Dict[str, Tensor]] = None,
-              meta_args: Optional[Dict[str, Tensor]] = None) -> Graph:
+    def trace(
+        self,
+        root: nn.Module,
+        concrete_args: Optional[Dict[str, Tensor]] = None,
+        meta_args: Optional[Dict[str, Tensor]] = None,
+    ) -> Graph:
         """
         Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow.
@@ -341,9 +348,7 @@ class ColoTracer(Tracer):
         # update concrete args with default values
         non_meta_arg_names = sig_names - meta_arg_names
         for k, v in sig.parameters.items():
-            if k in non_meta_arg_names and \
-                k not in concrete_args and \
-                v.default is not inspect.Parameter.empty:
+            if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
                 concrete_args[k] = v.default

         # get non concrete arg names
@@ -354,7 +359,8 @@ class ColoTracer(Tracer):
             success, element = is_element_in_list(names, sig_names)
             if not success:
                 raise KeyError(
-                    f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")
+                    f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function"
+                )

         _check_arg_name_valid(meta_arg_names)
         _check_arg_name_valid(concrete_arg_names)
@@ -363,11 +369,13 @@ class ColoTracer(Tracer):
         def _check_kwargs(kwargs, should_be_meta: bool):
             for k, v in kwargs.items():
                 if not should_be_meta:
-                    assert not torch.is_tensor(v) or not v.is_meta, \
-                        f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer'
+                    assert (
+                        not torch.is_tensor(v) or not v.is_meta
+                    ), f"Expected the {k} not to be a meta tensor, please check the args passed to the tracer"
                 else:
-                    assert v.is_meta == should_be_meta, \
-                        f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
+                    assert (
+                        v.is_meta == should_be_meta
+                    ), f"Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer"

         _check_kwargs(concrete_args, should_be_meta=False)
         _check_kwargs(meta_args, should_be_meta=True)
@@ -442,7 +450,6 @@ class ColoTracer(Tracer):
             orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction

             class PatchedCheckpointFunction(torch.autograd.Function):
-
                 @staticmethod
                 def forward(ctx, run_function, preserve_rng_state, *args):
                     # signal that the current tracing occurs within activation checkpoint part
@@ -455,7 +462,8 @@ class ColoTracer(Tracer):
                 @staticmethod
                 def backward(ctx: Any, *grad_outputs: Any) -> Any:
                     raise NotImplementedError(
-                        "We do not implement the backward pass as we only trace the forward pass.")
+                        "We do not implement the backward pass as we only trace the forward pass."
+                    )

             # override the checkpoint function
             torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
@@ -470,12 +478,11 @@ class ColoTracer(Tracer):
         if self.inside_torch_checkpoint_func:
             # annotate the activation checkpoint module
-            node.meta['activation_checkpoint'] = self.act_ckpt_region_count
+            node.meta["activation_checkpoint"] = self.act_ckpt_region_count
         return node


 def wrap_tensor_constructor_method(target):
-
     def look_for_proxy(*args, **kwargs):
         # find in pos vars
         for arg in args:
@@ -518,12 +525,10 @@ def wrap_tensor_constructor_method(target):
 for method in magic_methods:

     def _scope(method):
-
         def impl(*args, **kwargs):
-
             tracer = args[0].tracer
             target = getattr(operator, method)
-            proxy = tracer.create_proxy('call_function', target, args, kwargs)
+            proxy = tracer.create_proxy("call_function", target, args, kwargs)
             if not isinstance(proxy, ColoProxy):
                 meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
                 proxy = ColoProxy(proxy.node)
@@ -542,7 +547,7 @@ def _define_reflectable(orig_method_name):
     def impl(self, rhs):
         target = getattr(operator, orig_method_name)
-        proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {})
+        proxy = self.tracer.create_proxy("call_function", target, (rhs, self), {})
         if not isinstance(proxy, ColoProxy):
             meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {})
             proxy = ColoProxy(proxy.node)