[Gemini] Use async stream to prefetch and h2d data moving (#5781)

* use async stream to prefetch and h2d data moving

* Remove redundant code
This commit is contained in:
Haze188
2024-06-12 15:48:52 +08:00
committed by GitHub
parent 8554585a5f
commit d9dddf574f
4 changed files with 12 additions and 12 deletions

View File

@@ -5,6 +5,7 @@ from typing import List
import torch
from colossalai.accelerator import get_accelerator
from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini import TensorState
@@ -54,10 +55,11 @@ class GeminiZeROHook(ColoParamOpHook):
)
# prefetch
for chunk in chunks_fetch_async:
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
if maybe_work is not None:
self._gemini_manager.add_work(chunk, maybe_work)
with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream):
for chunk in chunks_fetch_async:
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
if maybe_work is not None:
self._gemini_manager.add_work(chunk, maybe_work)
# record cuda model data of the current OP, including memory for prefetched chunks
self._gemini_manager.record_model_data_volume()