[Device] Support npu (#6159)

* support npu

* support pretrain

* support lora

* support chatglm

* Update train.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: flybird11111
Date: 2024-12-17 15:42:39 +08:00
Committed by: GitHub
Parent: e994c64568
Commit: aaafb38851

18 changed files with 295 additions and 152 deletions


@@ -157,7 +157,7 @@ class GeminiDDP(ModelWrapper):
         self.enable_async_reduce = enable_async_reduce

         if enable_async_reduce:
-            self.async_reduce_stream = torch.cuda.Stream()
+            self.async_reduce_stream = get_accelerator().Stream()
         else:
             self.async_reduce_stream = None
@@ -363,7 +363,7 @@ class GeminiDDP(ModelWrapper):
         master_weights: bool,
         enable_gradient_accumulation: bool,
         p: nn.Parameter,
-        async_reduce_stream: Optional[torch.cuda.Stream] = None,
+        async_reduce_stream=None,
     ):
         async_reduce_scatter = async_reduce_stream is not None
         setattr(p, "_gemini_reduced", True)
@@ -402,9 +402,9 @@ class GeminiDDP(ModelWrapper):
                 grad_chunk.add_tensor_to_chunk_slice(p, grad)

         if async_reduce_stream is not None:
-            async_reduce_stream.wait_stream(torch.cuda.current_stream())
+            async_reduce_stream.wait_stream(get_accelerator().current_stream())

-        with torch.cuda.stream(async_reduce_stream):
+        with get_accelerator().stream(async_reduce_stream):
             reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce_scatter)
             if reduced:
                 grad_chunk.wait_async_reduce()
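
The hunks above show the core of the NPU port: hard-coded torch.cuda stream calls are routed through the accelerator abstraction so the same Gemini code path runs on CUDA and on Ascend NPU. Below is a minimal, standalone sketch of that pattern, not part of this commit. It assumes the colossalai.accelerator import path; only Stream(), current_stream(), stream() and wait_stream() are confirmed by the diff, while get_current_device() and synchronize() are assumptions about the accelerator API.

# Device-agnostic side-stream usage, mirroring the change in this commit.
import torch
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()        # resolves to CUDA, NPU, ... at runtime
side_stream = accelerator.Stream()     # replaces torch.cuda.Stream()

x = torch.ones(4, device=accelerator.get_current_device())

# Order the side stream after work already queued on the default stream,
# then launch the asynchronous op on the side stream.
side_stream.wait_stream(accelerator.current_stream())
with accelerator.stream(side_stream):  # replaces torch.cuda.stream(...)
    y = x * 2

accelerator.synchronize()              # join before reading the result
print(y)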