mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -39,7 +39,7 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
|
||||
_tensor: torch.Tensor
|
||||
_node: Node
|
||||
|
||||
__slots__ = ['_tensor', '_node']
|
||||
__slots__ = ["_tensor", "_node"]
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
|
||||
@@ -51,22 +51,22 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
|
||||
dtype=tensor.dtype,
|
||||
layout=tensor.layout,
|
||||
device=fake_device if fake_device is not None else tensor.device,
|
||||
requires_grad=tensor.requires_grad) # deceive the frontend for aten selections
|
||||
requires_grad=tensor.requires_grad,
|
||||
) # deceive the frontend for aten selections
|
||||
r._tensor = tensor
|
||||
if placeholder:
|
||||
if name is None:
|
||||
name = 'input'
|
||||
r._node = graph.create_node('placeholder',
|
||||
'placeholder', (graph._root,),
|
||||
name=namespace.create_name(name, tensor))
|
||||
name = "input"
|
||||
r._node = graph.create_node(
|
||||
"placeholder", "placeholder", (graph._root,), name=namespace.create_name(name, tensor)
|
||||
)
|
||||
# ...the real tensor is held as an element on the tensor.
|
||||
if not r._tensor.is_meta:
|
||||
r._tensor = r._tensor.to(torch.device('meta'))
|
||||
r._tensor = r._tensor.to(torch.device("meta"))
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
|
||||
def unwrap(x):
|
||||
nonlocal fake_device
|
||||
if isinstance(x, MetaProxy):
|
||||
@@ -75,21 +75,21 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
|
||||
# assert not isinstance(x, MetaProxy)
|
||||
elif isinstance(x, torch.Tensor):
|
||||
fake_device = x.device
|
||||
x = x.to(torch.device('meta'))
|
||||
x = x.to(torch.device("meta"))
|
||||
return x
|
||||
|
||||
def get_node(x):
|
||||
if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
|
||||
x = MetaProxy(x, placeholder=True, name='weight')
|
||||
return x if not hasattr(x, '_node') else x._node
|
||||
if isinstance(x, torch.Tensor) and not hasattr(x, "_node"):
|
||||
x = MetaProxy(x, placeholder=True, name="weight")
|
||||
return x if not hasattr(x, "_node") else x._node
|
||||
|
||||
args_node = tree_map(get_node, args)
|
||||
kwargs_node = tree_map(get_node, kwargs)
|
||||
node = graph.create_node('call_function', func, args_node, kwargs_node)
|
||||
node = graph.create_node("call_function", func, args_node, kwargs_node)
|
||||
|
||||
if 'device' in kwargs:
|
||||
fake_device = kwargs['device']
|
||||
kwargs['device'] = torch.device('meta')
|
||||
if "device" in kwargs:
|
||||
fake_device = kwargs["device"]
|
||||
kwargs["device"] = torch.device("meta")
|
||||
|
||||
args = tree_map(unwrap, args)
|
||||
kwargs = tree_map(unwrap, kwargs)
|
||||
@@ -103,9 +103,12 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
|
||||
if isinstance(x, torch.Tensor):
|
||||
nonlocal fake_device
|
||||
if not x.is_meta:
|
||||
x = x.to(torch.device('meta'))
|
||||
return MetaProxy(
|
||||
x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
|
||||
x = x.to(torch.device("meta"))
|
||||
return (
|
||||
MetaProxy(x, fake_device=fake_device)
|
||||
if isinstance(x, torch.Tensor) and not hasattr(x, "_tensor")
|
||||
else x
|
||||
)
|
||||
|
||||
def set_node(x):
|
||||
x._node = node
|
||||
@@ -125,9 +128,12 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
|
||||
|
||||
for tensor in normalize_tuple(out):
|
||||
if is_autogradable(tensor) and tensor.requires_grad:
|
||||
grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
|
||||
tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
|
||||
torch.autograd.backward(tensor,
|
||||
MetaProxy(grad, fake_device=tensor.device, placeholder=True),
|
||||
retain_graph=True)
|
||||
grad = (
|
||||
torch.empty_like(tensor._tensor, device=torch.device("meta"))
|
||||
if isinstance(tensor, MetaProxy)
|
||||
else torch.empty_like(tensor, device=torch.device("meta"))
|
||||
)
|
||||
torch.autograd.backward(
|
||||
tensor, MetaProxy(grad, fake_device=tensor.device, placeholder=True), retain_graph=True
|
||||
)
|
||||
return graph
|
||||
|
Reference in New Issue
Block a user