mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[zero] stateful tensor manager (#687)
* [WIP] stateful tensor manager * add eviction strategy * polish code * polish code * polish comment * add unit test * fix sampler bug * polish code * fix max sampling cnt resetting bug * fix sampler bug * polish code * fix bug * fix unit test Co-authored-by: jiaruifang <fangjiarui123@gmail.com>
This commit is contained in:
@@ -26,7 +26,7 @@ class ShardedParamV2(object):
|
||||
def get_payload_tensors(self) -> List[StatefulTensor]:
|
||||
"""returns stateful tensors kept by this class.
|
||||
"""
|
||||
return [self._sharded_data_tensor, self.saved_grad]
|
||||
return [self._sharded_data_tensor]
|
||||
|
||||
def remove_torch_payload(self):
|
||||
self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
|
||||
|
||||
Reference in New Issue
Block a user