mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
import uuid
|
||||
|
||||
import torch
|
||||
from torch.types import _bool, _device, _dtype
|
||||
from torch.utils._pytree import tree_flatten, tree_map
|
||||
from torch.types import _device
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from .._compatibility import compatibility
|
||||
from .constants import ALIAS_ATEN
|
||||
|
||||
__all__ = ['MetaTensor']
|
||||
__all__ = ["MetaTensor"]
|
||||
|
||||
|
||||
def set_data_ptr(x):
|
||||
@@ -43,12 +43,13 @@ class MetaTensor(torch.Tensor):
|
||||
storage_offset=elem.storage_offset(),
|
||||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device=fake_device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
|
||||
requires_grad=elem.requires_grad) # deceive the frontend for aten selections
|
||||
device=fake_device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
|
||||
requires_grad=elem.requires_grad,
|
||||
) # deceive the frontend for aten selections
|
||||
r._tensor = elem
|
||||
# ...the real tensor is held as an element on the tensor.
|
||||
if not r._tensor.is_meta:
|
||||
r._tensor = r._tensor.to(torch.device('meta'))
|
||||
r._tensor = r._tensor.to(torch.device("meta"))
|
||||
# only tensor not on `meta` should be copied to `meta`
|
||||
set_data_ptr(r._tensor)
|
||||
return r
|
||||
@@ -69,15 +70,15 @@ class MetaTensor(torch.Tensor):
|
||||
x = x._tensor
|
||||
elif isinstance(x, torch.Tensor):
|
||||
fake_device = x.device
|
||||
x = x.to(torch.device('meta'))
|
||||
x = x.to(torch.device("meta"))
|
||||
return x
|
||||
|
||||
args = tree_map(unwrap, args)
|
||||
kwargs = tree_map(unwrap, kwargs)
|
||||
|
||||
if 'device' in kwargs:
|
||||
fake_device = kwargs['device']
|
||||
kwargs['device'] = torch.device('meta')
|
||||
if "device" in kwargs:
|
||||
fake_device = kwargs["device"]
|
||||
kwargs["device"] = torch.device("meta")
|
||||
|
||||
# run aten for backend=CPU but actually on backend=Meta
|
||||
out = func(*args, **kwargs)
|
||||
@@ -93,7 +94,7 @@ class MetaTensor(torch.Tensor):
|
||||
if isinstance(x, torch.Tensor):
|
||||
nonlocal fake_device
|
||||
if not x.is_meta:
|
||||
x = x.to(torch.device('meta'))
|
||||
x = x.to(torch.device("meta"))
|
||||
return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
|
||||
|
||||
return tree_map(wrap, out)
|
||||
@@ -120,18 +121,18 @@ class MetaTensor(torch.Tensor):
|
||||
nonlocal fake_device
|
||||
if isinstance(x, str) or isinstance(x, _device):
|
||||
fake_device = x
|
||||
return 'meta'
|
||||
return "meta"
|
||||
return x
|
||||
|
||||
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
|
||||
return MetaTensor(elem, fake_device=fake_device)
|
||||
|
||||
def cpu(self, *args, **kwargs):
|
||||
if self.device.type == 'cpu':
|
||||
if self.device.type == "cpu":
|
||||
return self.to(*args, **kwargs)
|
||||
return self.to(*args, device='cpu', **kwargs)
|
||||
return self.to(*args, device="cpu", **kwargs)
|
||||
|
||||
def cuda(self, device=None, non_blocking=False):
|
||||
if device is not None:
|
||||
return self.to(device=device, non_blocking=non_blocking)
|
||||
return self.to(device='cuda:0', non_blocking=non_blocking)
|
||||
return self.to(device="cuda:0", non_blocking=non_blocking)
|
||||
|
Reference in New Issue
Block a user