[tensor] ZeRO use ColoTensor as the base class. (#828)

* [refactor] moving InsertPostInitMethodToModuleSubClasses to utils.

* [tensor] ZeRO use ColoTensor as the base class.

* polish
This commit is contained in:
Jiarui Fang
2022-04-22 12:00:48 +08:00
committed by GitHub
parent 8e6fdb4f29
commit 294a6060d0
4 changed files with 36 additions and 26 deletions

View File

@@ -60,8 +60,8 @@ def test_no_wrap_op():
assert torch.sum(input=t) == torch.sum(input=t_ref)
def test_lazy_init_tensor():
lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor == None
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor.numel() == 0
assert lazy_t.torch_tensor().numel() == 6
def check_all():