Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 11:32:10 +00:00
[Device] Support npu (#6159)
* support npu
* support pretrain
* support lora
* support chatglm
* Update train.py
* assorted fixes, plus [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -157,7 +157,7 @@ class GeminiDDP(ModelWrapper):
         self.enable_async_reduce = enable_async_reduce

         if enable_async_reduce:
-            self.async_reduce_stream = torch.cuda.Stream()
+            self.async_reduce_stream = get_accelerator().Stream()
         else:
             self.async_reduce_stream = None
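For context, a minimal sketch of the device-agnostic pattern this hunk adopts. It assumes the get_accelerator entry point the patch itself calls (imported here from colossalai.accelerator, the path used elsewhere in the codebase); the helper name make_async_reduce_stream is illustrative, not part of the patch.

# Sketch: create the async-reduce side stream without hard-coding CUDA.
# get_accelerator() returns a backend wrapper (CUDA, NPU, ...) whose
# Stream() maps to torch.cuda.Stream on GPUs and to the Ascend NPU
# equivalent when running on torch_npu.
from colossalai.accelerator import get_accelerator

def make_async_reduce_stream(enable_async_reduce: bool):
    # Mirrors the branch above: a side stream only when async reduce is on.
    return get_accelerator().Stream() if enable_async_reduce else None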
@@ -363,7 +363,7 @@ class GeminiDDP(ModelWrapper):
         master_weights: bool,
         enable_gradient_accumulation: bool,
         p: nn.Parameter,
-        async_reduce_stream: Optional[torch.cuda.Stream] = None,
+        async_reduce_stream=None,
     ):
         async_reduce_scatter = async_reduce_stream is not None
         setattr(p, "_gemini_reduced", True)
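Dropping the Optional[torch.cuda.Stream] annotation here is presumably deliberate: on an NPU backend the stream returned by get_accelerator().Stream() is not a torch.cuda.Stream, so the CUDA-specific hint would be wrong; leaving the parameter untyped keeps the signature backend-agnostic.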
@@ -402,9 +402,9 @@ class GeminiDDP(ModelWrapper):
         grad_chunk.add_tensor_to_chunk_slice(p, grad)

         if async_reduce_stream is not None:
-            async_reduce_stream.wait_stream(torch.cuda.current_stream())
+            async_reduce_stream.wait_stream(get_accelerator().current_stream())

-        with torch.cuda.stream(async_reduce_stream):
+        with get_accelerator().stream(async_reduce_stream):
             reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce_scatter)
             if reduced:
                 grad_chunk.wait_async_reduce()
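The hunk above is the standard side-stream handoff: the reduce stream first waits for the current stream to finish producing the gradient, then the reduction is launched inside the side stream's context so it can overlap with subsequent work on the default stream. A self-contained sketch of the same pattern, written directly against torch.cuda for concreteness; the in-place add_ is an illustrative stand-in for chunk_manager.reduce_chunk, not the actual reduction:

import torch

def async_accumulate(dst: torch.Tensor, src: torch.Tensor,
                     side_stream: torch.cuda.Stream) -> None:
    # 1. The side stream must wait until the default stream has finished
    #    producing `src` before the accumulation reads it.
    side_stream.wait_stream(torch.cuda.current_stream())
    # 2. Kernels launched inside this context run on side_stream,
    #    overlapping with later work queued on the default stream.
    with torch.cuda.stream(side_stream):
        dst.add_(src)
    # 3. Tell the caching allocator these tensors are in use on side_stream,
    #    so their memory is not reused before the async add completes.
    src.record_stream(side_stream)
    dst.record_stream(side_stream)

Usage, with an explicit join before the result is read back on the default stream:

if torch.cuda.is_available():
    s = torch.cuda.Stream()
    a = torch.zeros(1024, device="cuda")
    b = torch.ones(1024, device="cuda")
    async_accumulate(a, b, s)
    torch.cuda.current_stream().wait_stream(s)  # join before reading `a`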