[gemini] zero supports gemini (#1093)

* add placement policy

* add gemini mgr

* update mem stats collector

* update zero

* update zero optim

* fix bugs

* zero optim monitor os

* polish unit test

* polish unit test

* add assert
This commit is contained in:
ver217
2022-06-10 14:48:28 +08:00
committed by GitHub
parent 2b2dc1c86b
commit 1f894e033f
9 changed files with 366 additions and 12 deletions

View File

@@ -178,6 +178,9 @@ class Chunk:
def __eq__(self, __o: object) -> bool:
return self is __o
def get_tensors(self) -> List[torch.Tensor]:
    """Return the keys of ``self.tensors_info`` as a newly built list.

    The returned list is a fresh object on every call, so mutating it
    does not alter ``self.tensors_info``.
    """
    tracked = self.tensors_info.keys()
    return list(tracked)
class ChunkManager:
@@ -234,6 +237,10 @@ class ChunkManager:
def access_chunk(self, chunk: Chunk) -> None:
if chunk in self.accessed_chunks:
if chunk.device_type != 'cuda':
self.total_mem[chunk.device_type] -= chunk.mem
chunk.move_device(get_current_device())
self.total_mem[chunk.device_type] += chunk.mem
return
if not chunk.is_free:
self.total_mem[chunk.device_type] -= chunk.mem