mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -3,12 +3,12 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.types import _bool, _device, _dtype
|
||||
from torch.utils._pytree import tree_flatten, tree_map
|
||||
from torch.types import _device
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod
|
||||
|
||||
__all__ = ['MetaTensor', 'MetaTensorMode']
|
||||
__all__ = ["MetaTensor", "MetaTensorMode"]
|
||||
|
||||
|
||||
def register_storage(r, data_ptr_fn=None):
|
||||
@@ -28,8 +28,7 @@ def _normalize_tuple(x):
|
||||
|
||||
# a hack of inplace execution in PyTorch
|
||||
def _assert_alias(func):
|
||||
return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen # TODO: check if should be this aggressive
|
||||
)
|
||||
return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen) # TODO: check if should be this aggressive
|
||||
|
||||
|
||||
class MetaTensor(torch.Tensor):
|
||||
@@ -65,14 +64,15 @@ class MetaTensor(torch.Tensor):
|
||||
storage_offset=elem.storage_offset(),
|
||||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
|
||||
requires_grad=requires_grad) # deceive the frontend for aten selections
|
||||
device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
|
||||
requires_grad=requires_grad,
|
||||
) # deceive the frontend for aten selections
|
||||
r._tensor = elem
|
||||
# ...the real tensor is held as an element on the tensor.
|
||||
if not r._tensor.is_meta:
|
||||
val = elem.data_ptr()
|
||||
data_ptr_fn = lambda: val
|
||||
r._tensor = r._tensor.to(torch.device('meta'))
|
||||
r._tensor = r._tensor.to(torch.device("meta"))
|
||||
|
||||
# only tensor not on `meta` should be copied to `meta`
|
||||
register_storage(r._tensor, data_ptr_fn)
|
||||
@@ -81,7 +81,7 @@ class MetaTensor(torch.Tensor):
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor'
|
||||
name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor"
|
||||
if self.grad_fn:
|
||||
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
|
||||
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
|
||||
@@ -97,15 +97,15 @@ class MetaTensor(torch.Tensor):
|
||||
x = x._tensor
|
||||
elif isinstance(x, torch.Tensor):
|
||||
device = x.device
|
||||
x = x.to(torch.device('meta'))
|
||||
x = x.to(torch.device("meta"))
|
||||
return x
|
||||
|
||||
args = tree_map(unwrap, args)
|
||||
kwargs = tree_map(unwrap, kwargs)
|
||||
|
||||
if 'device' in kwargs:
|
||||
device = kwargs['device']
|
||||
kwargs['device'] = torch.device('meta')
|
||||
if "device" in kwargs:
|
||||
device = kwargs["device"]
|
||||
kwargs["device"] = torch.device("meta")
|
||||
|
||||
# run aten for backend=CPU but actually on backend=Meta
|
||||
# here we detect whether or not the execution generates a physical copy
|
||||
@@ -143,21 +143,21 @@ class MetaTensor(torch.Tensor):
|
||||
nonlocal device
|
||||
if isinstance(x, str) or isinstance(x, _device):
|
||||
device = x
|
||||
return torch.device('meta')
|
||||
return torch.device("meta")
|
||||
return x
|
||||
|
||||
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
|
||||
return MetaTensor(elem, device=device)
|
||||
|
||||
def cpu(self, *args, **kwargs):
|
||||
if self.device.type == 'cpu':
|
||||
if self.device.type == "cpu":
|
||||
return self.to(*args, **kwargs)
|
||||
return self.to(*args, device='cpu', **kwargs)
|
||||
return self.to(*args, device="cpu", **kwargs)
|
||||
|
||||
def cuda(self, device=None, non_blocking=False):
|
||||
if device is not None:
|
||||
return self.to(device=device, non_blocking=non_blocking)
|
||||
return self.to(device='cuda:0', non_blocking=non_blocking)
|
||||
return self.to(device="cuda:0", non_blocking=non_blocking)
|
||||
|
||||
def data_ptr(self):
|
||||
return self._tensor.data_ptr()
|
||||
@@ -177,19 +177,17 @@ class MetaTensorMode(object):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.torch_overrides = {} # override torch.xxx
|
||||
self.dist_overrides = {} # override torch.distributed.xxx
|
||||
self.torch_overrides = {} # override torch.xxx
|
||||
self.dist_overrides = {} # override torch.distributed.xxx
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
def _dummy(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def _new(*args, orig_new=torch.empty, **kwargs):
|
||||
return MetaTensor(orig_new(*args, **{
|
||||
**kwargs, 'device': 'meta'
|
||||
}),
|
||||
device=kwargs.get('device', torch.device('cpu')))
|
||||
return MetaTensor(
|
||||
orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu"))
|
||||
)
|
||||
|
||||
for func in _TorchOverrideableFactoryMethod:
|
||||
self.torch_overrides[func] = getattr(torch, func)
|
||||
|
Reference in New Issue
Block a user