[utils] lazy init. (#2148)

* [utils] lazy init.

* [utils] remove description.

* [utils] complete.

* [utils] finalize.

* [utils] fix names.
Author: Super Daniel
Date: 2023-01-20 10:49:00 +08:00
Committed by: GitHub
Parent: 72341e65f4
Commit: 35c0c0006e
2 changed files with 461 additions and 24 deletions
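For context on the feature itself: the 461 added lines land in a new lazy-init utility file (not shown in the excerpt below), while the MetaTensor changes shown here adjust the wrapper that lazy init builds on. The general idea of lazy initialization, sketched below with stock PyTorch APIs rather than the API added by this PR (the helper name is illustrative, not ColossalAI's), is to construct a model on the meta device so no parameter storage is allocated, then materialize it once the real placement is known:

import torch
import torch.nn as nn

def lazy_build(model_fn):
    # hypothetical helper: construct under the meta device so only
    # shapes/dtypes are recorded and no parameter memory is allocated
    with torch.device('meta'):    # torch.device as context manager, PyTorch >= 2.0
        return model_fn()

model = lazy_build(lambda: nn.Linear(1024, 1024))
print(model.weight.device)        # meta
# materialize uninitialized storage later, once placement is decided
model = model.to_empty(device='cpu')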
@@ -1,6 +1,4 @@
-import uuid
-from copy import deepcopy
 from typing import Optional
 import torch
 from torch.types import _bool, _device, _dtype
@@ -28,8 +26,6 @@ class MetaTensor(torch.Tensor):
     _tensor: torch.Tensor
     __slots__ = ['_tensor']
     @staticmethod
     def __new__(cls, elem, fake_device=None):
         # Avoid multiple wrapping
@@ -47,7 +43,7 @@ class MetaTensor(torch.Tensor):
             storage_offset=elem.storage_offset(),
             dtype=elem.dtype,
             layout=elem.layout,
-            device=fake_device if fake_device is not None else elem.device,
+            device=fake_device if fake_device is not None else torch.device('cpu'),
             requires_grad=elem.requires_grad)    # deceive the frontend for aten selections
         r._tensor = elem
         # ...the real tensor is held as an element on the tensor.
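The fallback change above means a tensor wrapped without an explicit fake_device no longer reports whatever device the element happens to be on (which, for a meta element, would be 'meta'); it now pretends to be on CPU. A hedged illustration, assuming the MetaTensor class from this file:

elem = torch.empty(4, 4, device='meta')
x = MetaTensor(elem)        # no fake_device given -> defaults to torch.device('cpu')
print(x.device)             # cpu  (what the frontend is told)
print(x._tensor.device)     # meta (where the element really lives)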
@@ -59,8 +55,8 @@ class MetaTensor(torch.Tensor):
     def __repr__(self):
         if self.grad_fn:
-            return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})"
-        return f"MetaTensor({self._tensor}, fake_device='{self.device}')"
+            return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
+        return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
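The new __repr__ stops printing the wrapped meta tensor verbatim and instead summarizes shape, fake device, and dtype. Expected output, read straight off the f-strings above (exact formatting is an assumption):

x = MetaTensor(torch.empty(2, 3, device='meta'), fake_device='cuda:0')
print(x)    # MetaTensor(..., size=(2, 3), device='cuda:0', dtype=torch.float32)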
@@ -76,13 +72,13 @@ class MetaTensor(torch.Tensor):
                 x = x.to(torch.device('meta'))
             return x
-        args = tree_map(unwrap, args)
-        kwargs = tree_map(unwrap, kwargs)
         if 'device' in kwargs:
             fake_device = kwargs['device']
             kwargs['device'] = torch.device('meta')
+        args = tree_map(unwrap, args)
+        kwargs = tree_map(unwrap, kwargs)
         # run aten for backend=CPU but actually on backend=Meta
         out = func(*args, **kwargs)
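The reorder above matters for factory-style ops: the requested device is captured from kwargs and rewritten to 'meta' before unwrap runs, so the underlying aten call always executes on the meta backend while the caller's device is remembered as the fake device. A sketch of the intended behavior (an assumption about the surrounding code, not output verified against this PR):

x = MetaTensor(torch.empty(8, device='meta'), fake_device='cuda:0')
y = torch.empty_like(x)    # routed through __torch_dispatch__; runs on meta
z = x + x                  # aten.add also executes on the meta backend
print(z.device)            # cuda:0 -- the remembered fake device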
@@ -118,23 +114,24 @@ class MetaTensor(torch.Tensor):
         MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
         """
         # this imitates c++ function in the way of @overload
-        device = None
-        for arg in args:
-            if isinstance(arg, str) or isinstance(arg, _device):
-                device = arg
-        if 'device' in kwargs:
-            device = kwargs['device']
-        result = super().to(*args, **kwargs)
-        if device is not None:
-            result = MetaTensor(result, fake_device=device)
-        return result
+        fake_device = None
+
+        def replace(x):
+            nonlocal fake_device
+            if isinstance(x, str) or isinstance(x, _device):
+                fake_device = x
+                return 'meta'
+            return x
+
+        elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
+        return MetaTensor(elem, fake_device=fake_device)
 
     def cpu(self, *args, **kwargs):
         if self.device.type == 'cpu':
             return self.to(*args, **kwargs)
         return self.to(*args, device='cpu', **kwargs)
 
-    def cuda(self, *args, **kwargs):
-        if self.device.type == 'cuda':
-            return self.to(*args, **kwargs)
-        return self.to(*args, device='cuda', **kwargs)
+    def cuda(self, device=None, non_blocking=False):
+        if device is not None:
+            return self.to(device=device, non_blocking=non_blocking)
+        return self.to(device='cuda:0', non_blocking=non_blocking)
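The rewritten to() funnels every positional or keyword device argument through replace(), swapping it for 'meta' so the element never allocates real storage while the requested device is remembered, and cuda() now mirrors torch.Tensor.cuda(device=None, non_blocking=False) with a 'cuda:0' default, which is plausibly what the "fix names" message refers to. A usage sketch under those assumptions:

x = MetaTensor(torch.empty(10, device='meta'))    # fake_device defaults to cpu
y = x.to('cuda:1', dtype=torch.half)    # replace() records 'cuda:1', runs on meta
print(y.device, y.dtype)                # cuda:1 torch.float16
print(y._tensor.device)                 # meta -- no real allocation happened
print(x.cuda().device)                  # cuda:0 (new default)
print(x.cuda('cuda:2').device)          # cuda:2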