[zero] stateful tensor manager (#687)

* [WIP] stateful tensor manager

* add eviction strategy

* polish code

* polish code

* polish comment

* add unit test

* fix sampler bug

* polish code

* fix max sampling cnt resetting bug

* fix sampler bug

* polish code

* fix bug

* fix unit test

Co-authored-by: jiaruifang <fangjiarui123@gmail.com>
This commit is contained in:
ver217
2022-04-08 17:51:34 +08:00
committed by GitHub
parent 70e8dd418b
commit 3c9cd5bb5e
8 changed files with 271 additions and 73 deletions

View File

@@ -26,7 +26,7 @@ class ShardedParamV2(object):
def get_payload_tensors(self) -> List[StatefulTensor]:
"""returns stateful tensors kept by this class.
"""
return [self._sharded_data_tensor, self.saved_grad]
return [self._sharded_data_tensor]
def remove_torch_payload(self):
self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)