[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -54,7 +54,7 @@ def _current_device(module):
try:
return next(module.parameters()).device
except StopIteration:
return torch.device('cpu')
return torch.device("cpu")
@compatibility(is_backward_compatible=False)
@@ -90,6 +90,7 @@ class ShapeProp(torch.fx.Interpreter):
>>> # do something here
>>> return torch.empty(output_shape, device=output_device)
"""
_custom_dispatch_func = {}
_mode = MetaTensorMode()
@@ -115,15 +116,14 @@ class ShapeProp(torch.fx.Interpreter):
r = getattr(self, n.op)(n.target, args, kwargs)
def unwrap_fn(elem):
def _convert_meta(t: torch.Tensor):
if t.device == 'meta':
if t.device == "meta":
return t
else:
return t.to('meta')
return t.to("meta")
if isinstance(elem, MetaTensor):
if getattr(self, '_is_param', False):
if getattr(self, "_is_param", False):
return torch.nn.Parameter(_convert_meta(elem._tensor))
return _convert_meta(elem._tensor)
@@ -139,21 +139,24 @@ class ShapeProp(torch.fx.Interpreter):
n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r)
if n.op == 'call_module':
if n.op == "call_module":
submod = self.fetch_attr(n.target)
n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})
else:
n_info.parameters.update({
k.name: MetaTensor(v)
for k, v in zip(n.args, args)
if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
})
n_info.parameters.update(
{
k.name: MetaTensor(v)
for k, v in zip(n.args, args)
if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
}
)
n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})
n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
tuple(v for v in kwargs.values() if is_pure_tensor(v))
n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + tuple(
v for v in kwargs.values() if is_pure_tensor(v)
)
# align with SPMD
if isinstance(r, (tuple, list)):
@@ -168,7 +171,7 @@ class ShapeProp(torch.fx.Interpreter):
n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
return r
def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
def call_function(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the result.
If the target of ``Node`` is registered with ``@register_shape_impl``,
@@ -197,7 +200,7 @@ class ShapeProp(torch.fx.Interpreter):
else:
return res
def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
def call_method(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the result.
@@ -218,7 +221,8 @@ class ShapeProp(torch.fx.Interpreter):
convert_to_parameter = False
if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
args[0], torch.nn.parameter.Parameter):
args[0], torch.nn.parameter.Parameter
):
convert_to_parameter = True
# Execute the method and return the result
assert isinstance(target, str)