mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 22:52:25 +00:00
[utils] lazy init. (#2148)
* [utils] lazy init. * [utils] remove description. * [utils] complete. * [utils] finalize. * [utils] fix names.
This commit is contained in:
@@ -1,6 +1,4 @@
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.types import _bool, _device, _dtype
|
||||
@@ -28,8 +26,6 @@ class MetaTensor(torch.Tensor):
|
||||
|
||||
_tensor: torch.Tensor
|
||||
|
||||
__slots__ = ['_tensor']
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, elem, fake_device=None):
|
||||
# Avoid multiple wrapping
|
||||
@@ -47,7 +43,7 @@ class MetaTensor(torch.Tensor):
|
||||
storage_offset=elem.storage_offset(),
|
||||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device=fake_device if fake_device is not None else elem.device,
|
||||
device=fake_device if fake_device is not None else torch.device('cpu'),
|
||||
requires_grad=elem.requires_grad) # deceive the frontend for aten selections
|
||||
r._tensor = elem
|
||||
# ...the real tensor is held as an element on the tensor.
|
||||
@@ -59,8 +55,8 @@ class MetaTensor(torch.Tensor):
|
||||
|
||||
def __repr__(self):
|
||||
if self.grad_fn:
|
||||
return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})"
|
||||
return f"MetaTensor({self._tensor}, fake_device='{self.device}')"
|
||||
return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
|
||||
return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
@@ -76,13 +72,13 @@ class MetaTensor(torch.Tensor):
|
||||
x = x.to(torch.device('meta'))
|
||||
return x
|
||||
|
||||
args = tree_map(unwrap, args)
|
||||
kwargs = tree_map(unwrap, kwargs)
|
||||
|
||||
if 'device' in kwargs:
|
||||
fake_device = kwargs['device']
|
||||
kwargs['device'] = torch.device('meta')
|
||||
|
||||
args = tree_map(unwrap, args)
|
||||
kwargs = tree_map(unwrap, kwargs)
|
||||
|
||||
# run aten for backend=CPU but actually on backend=Meta
|
||||
out = func(*args, **kwargs)
|
||||
|
||||
@@ -118,23 +114,24 @@ class MetaTensor(torch.Tensor):
|
||||
MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
|
||||
"""
|
||||
# this imitates c++ function in the way of @overload
|
||||
device = None
|
||||
for arg in args:
|
||||
if isinstance(arg, str) or isinstance(arg, _device):
|
||||
device = arg
|
||||
if 'device' in kwargs:
|
||||
device = kwargs['device']
|
||||
result = super().to(*args, **kwargs)
|
||||
if device is not None:
|
||||
result = MetaTensor(result, fake_device=device)
|
||||
return result
|
||||
fake_device = None
|
||||
|
||||
def replace(x):
|
||||
nonlocal fake_device
|
||||
if isinstance(x, str) or isinstance(x, _device):
|
||||
fake_device = x
|
||||
return 'meta'
|
||||
return x
|
||||
|
||||
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
|
||||
return MetaTensor(elem, fake_device=fake_device)
|
||||
|
||||
def cpu(self, *args, **kwargs):
|
||||
if self.device.type == 'cpu':
|
||||
return self.to(*args, **kwargs)
|
||||
return self.to(*args, device='cpu', **kwargs)
|
||||
|
||||
def cuda(self, *args, **kwargs):
|
||||
if self.device.type == 'cuda':
|
||||
return self.to(*args, **kwargs)
|
||||
return self.to(*args, device='cuda', **kwargs)
|
||||
def cuda(self, device=None, non_blocking=False):
|
||||
if device is not None:
|
||||
return self.to(device=device, non_blocking=non_blocking)
|
||||
return self.to(device='cuda:0', non_blocking=non_blocking)
|
||||
|
Reference in New Issue
Block a user