[fx] add more meta_registry for MetaTensor execution. (#2000)

* [sc] add examples for auto checkpoint.

* merge upstream

* [fx] add more meta_registry for MetaTensor execution.
This commit is contained in:
Super Daniel
2022-11-23 10:55:46 +08:00
committed by GitHub
parent d00d905b86
commit 2edbef13cc
4 changed files with 70 additions and 21 deletions

View File

@@ -128,3 +128,13 @@ class MetaTensor(torch.Tensor):
if device is not None:
result = MetaTensor(result, fake_device=device)
return result
def cpu(self, *args, **kwargs):
if self.device.type == 'cpu':
return self.to(*args, **kwargs)
return self.to(*args, device='cpu', **kwargs)
def cuda(self, *args, **kwargs):
if self.device.type == 'cuda':
return self.to(*args, **kwargs)
return self.to(*args, device='cuda', **kwargs)