mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -54,7 +54,7 @@ def _current_device(module):
|
||||
try:
|
||||
return next(module.parameters()).device
|
||||
except StopIteration:
|
||||
return torch.device('cpu')
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
@@ -90,6 +90,7 @@ class ShapeProp(torch.fx.Interpreter):
|
||||
>>> # do something here
|
||||
>>> return torch.empty(output_shape, device=output_device)
|
||||
"""
|
||||
|
||||
_custom_dispatch_func = {}
|
||||
_mode = MetaTensorMode()
|
||||
|
||||
@@ -115,15 +116,14 @@ class ShapeProp(torch.fx.Interpreter):
|
||||
r = getattr(self, n.op)(n.target, args, kwargs)
|
||||
|
||||
def unwrap_fn(elem):
|
||||
|
||||
def _convert_meta(t: torch.Tensor):
|
||||
if t.device == 'meta':
|
||||
if t.device == "meta":
|
||||
return t
|
||||
else:
|
||||
return t.to('meta')
|
||||
return t.to("meta")
|
||||
|
||||
if isinstance(elem, MetaTensor):
|
||||
if getattr(self, '_is_param', False):
|
||||
if getattr(self, "_is_param", False):
|
||||
return torch.nn.Parameter(_convert_meta(elem._tensor))
|
||||
return _convert_meta(elem._tensor)
|
||||
|
||||
@@ -139,21 +139,24 @@ class ShapeProp(torch.fx.Interpreter):
|
||||
n_info = MetaInfo(n)
|
||||
n_info.outputs = _normalize_tuple(r)
|
||||
|
||||
if n.op == 'call_module':
|
||||
if n.op == "call_module":
|
||||
submod = self.fetch_attr(n.target)
|
||||
n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
|
||||
n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})
|
||||
|
||||
else:
|
||||
n_info.parameters.update({
|
||||
k.name: MetaTensor(v)
|
||||
for k, v in zip(n.args, args)
|
||||
if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
|
||||
})
|
||||
n_info.parameters.update(
|
||||
{
|
||||
k.name: MetaTensor(v)
|
||||
for k, v in zip(n.args, args)
|
||||
if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
|
||||
}
|
||||
)
|
||||
n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})
|
||||
|
||||
n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
|
||||
tuple(v for v in kwargs.values() if is_pure_tensor(v))
|
||||
n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + tuple(
|
||||
v for v in kwargs.values() if is_pure_tensor(v)
|
||||
)
|
||||
|
||||
# align with SPMD
|
||||
if isinstance(r, (tuple, list)):
|
||||
@@ -168,7 +171,7 @@ class ShapeProp(torch.fx.Interpreter):
|
||||
n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
|
||||
return r
|
||||
|
||||
def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_function(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_function`` node and return the result.
|
||||
If the target of ``Node`` is registered with ``@register_shape_impl``,
|
||||
@@ -197,7 +200,7 @@ class ShapeProp(torch.fx.Interpreter):
|
||||
else:
|
||||
return res
|
||||
|
||||
def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_method(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_method`` node and return the result.
|
||||
|
||||
@@ -218,7 +221,8 @@ class ShapeProp(torch.fx.Interpreter):
|
||||
|
||||
convert_to_parameter = False
|
||||
if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
|
||||
args[0], torch.nn.parameter.Parameter):
|
||||
args[0], torch.nn.parameter.Parameter
|
||||
):
|
||||
convert_to_parameter = True
|
||||
# Execute the method and return the result
|
||||
assert isinstance(target, str)
|
||||
|
Reference in New Issue
Block a user