[Gemini] Use async stream to prefetch and h2d data moving (#5781)

* use async stream to prefetch and h2d data moving

* Remove redundant code
This commit is contained in:
Haze188
2024-06-12 15:48:52 +08:00
committed by GitHub
parent 8554585a5f
commit d9dddf574f
4 changed files with 12 additions and 12 deletions

View File

@@ -104,9 +104,7 @@ class GeminiDDP(ModelWrapper):
self.enable_gradient_accumulation = enable_gradient_accumulation
if chunk_config_dict is not None:
self.chunk_manager = ChunkManager(
chunk_config_dict,
chunk_init_device,
reuse_fp16_chunk=reuse_fp16_chunk,
chunk_config_dict, chunk_init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch
)
else:
# some ugly hotfix for the compatibility with Lightning
@@ -122,6 +120,7 @@ class GeminiDDP(ModelWrapper):
process_group=zero_group,
reuse_fp16_chunk=reuse_fp16_chunk,
verbose=verbose,
max_prefetch=max_prefetch,
)
self.gemini_manager = GeminiManager(
placement_policy,