[gemini] optimize reduce scatter d2h copy (#5760)

* [gemini] optimize reduce scatter d2h copy

* [fix] fix missing reduce variable

* [refactor] remove legacy async reduce scatter code

* [gemini] missing sync

* Revert "[refactor] remove legacy async reduce scatter code"

This reverts commit 58ad76d466.

* [gemini] further optimize with async all reduce

* [fix] pass flag from manager to chunk

This commit is contained in:
botbw
2024-06-05 14:23:13 +08:00
committed by GitHub
parent 10a19e22c6
commit 3f7e3131d9
4 changed files with 52 additions and 62 deletions

View File

@@ -369,6 +369,11 @@ class GeminiPlugin(DPPluginBase):
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
if get_accelerator().name == "npu":
assert placement_policy == "static", "NPU only supports static placement policy"
if placement_policy == "auto" and enable_async_reduce:
logging.warning(
f"enable_async_reduce requires pin_memory to achieve best performance, which is not implicitly set."
)
pin_memory = True
self.gemini_config = dict(
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),