[Gemini] Use async stream to prefetch and h2d data moving (#5781)

* use async stream to prefetch and h2d data moving

* Remove redundant code
This commit is contained in:
Haze188
2024-06-12 15:48:52 +08:00
committed by GitHub
parent 8554585a5f
commit d9dddf574f
4 changed files with 12 additions and 12 deletions

View File

@@ -25,6 +25,7 @@ class ChunkManager:
chunk_configuration,
init_device: Optional[torch.device] = None,
reuse_fp16_chunk: bool = True,
max_prefetch: int = 0,
) -> None:
self.device = init_device or get_accelerator().get_current_device()
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
@@ -42,6 +43,7 @@ class ChunkManager:
# Whether model is accumulating gradients,
self.accumulating_grads = False
self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None
def register_tensor(
self,