mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[fx] add more meta_registry for MetaTensor execution. (#2000)
* [sc] add examples for auto checkpoint. * merge upstream * [fx] add more meta_registry for MetaTensor execution.
This commit is contained in:
@@ -128,3 +128,13 @@ class MetaTensor(torch.Tensor):
|
||||
if device is not None:
|
||||
result = MetaTensor(result, fake_device=device)
|
||||
return result
|
||||
|
||||
def cpu(self, *args, **kwargs):
|
||||
if self.device.type == 'cpu':
|
||||
return self.to(*args, **kwargs)
|
||||
return self.to(*args, device='cpu', **kwargs)
|
||||
|
||||
def cuda(self, *args, **kwargs):
|
||||
if self.device.type == 'cuda':
|
||||
return self.to(*args, **kwargs)
|
||||
return self.to(*args, device='cuda', **kwargs)
|
||||
|
Reference in New Issue
Block a user