[Device]Support npu (#6159)

* support npu

* support pretrain

support pretrain

fix

* support lora

fix

fix

* support chatglm

fix

fxi

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

fix

fix

* Update train.py

* Update train.py

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111
2024-12-17 15:42:39 +08:00
committed by GitHub
parent e994c64568
commit aaafb38851
18 changed files with 295 additions and 152 deletions

View File

@@ -14,6 +14,8 @@ from torch.distributed import ProcessGroup
from torch.distributed import distributed_c10d as c10d
from torch.utils._pytree import tree_flatten, tree_unflatten
from colossalai.accelerator import get_accelerator
from .stage_manager import PipelineStageManager
@@ -31,7 +33,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
buf = tensor.numpy().tobytes()[:tensor_size]
if b"cuda" in buf:
buf_array = bytearray(buf)
device_index = torch.cuda.current_device()
device_index = get_accelerator().current_device()
# There might be more than one output tensors during forward
for cuda_str in re.finditer(b"cuda", buf_array):
pos = cuda_str.start()
@@ -86,7 +88,7 @@ def _broadcast_object_list(
else:
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
current_device = torch.device("cuda", get_accelerator().current_device())
my_rank = dist.get_rank()
# Serialize object_list elements to tensors on src rank.
@@ -139,14 +141,14 @@ def _broadcast_object_list(
# unconsistence in device
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != torch.cuda.current_device()
and unpickle_object.device.index != get_accelerator().current_device()
):
unpickle_object = unpickle_object.cuda()
unpickle_object = unpickle_object.to(get_accelerator().current_device())
object_list[i] = unpickle_object
def _check_for_nccl_backend(group):
def _check_for_nccl_hccl_backend(group):
pg = group or c10d._get_default_group()
# Gate PG wrapper check on Gloo availability.
if c10d._GLOO_AVAILABLE:
@@ -154,14 +156,14 @@ def _check_for_nccl_backend(group):
while isinstance(pg, c10d._ProcessGroupWrapper):
pg = pg.wrapped_pg
return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and pg.name() == c10d.Backend.NCCL
def _check_device(group):
is_nccl_backend = _check_for_nccl_backend(group)
is_nccl_backend = _check_for_nccl_hccl_backend(group)
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
current_device = torch.device(get_accelerator().current_device())
return current_device, is_nccl_backend
@@ -348,8 +350,11 @@ def _send_recv_serialization_object(
unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())
if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
unpickle_object = unpickle_object.cuda()
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != get_accelerator().current_device()
):
unpickle_object = unpickle_object.to(get_accelerator().current_device())
return unpickle_object
@@ -474,9 +479,11 @@ def _p2p_comm(
recv_prev_shape = None
if tensor_send_next is not None:
send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
send_next_shape = torch.tensor(
tensor_send_next.size(), device=get_accelerator().current_device(), dtype=torch.int64
)
if recv_prev:
recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
recv_prev_shape = torch.empty((3), device=get_accelerator().current_device(), dtype=torch.int64)
ops = []
if send_next_shape is not None:
@@ -501,7 +508,7 @@ def _p2p_comm(
# send and recv data
tensor_recv_prev = None
if recv_prev:
tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)
tensor_recv_prev = torch.empty(recv_prev_shape, device=get_accelerator().current_device(), dtype=comm_dtype)
ops = []
if tensor_send_next is not None:

View File

@@ -2,7 +2,6 @@ from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
@@ -18,7 +17,7 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_
from .base import PipelineSchedule
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
def _wait_p2p(wait_handles) -> None:
if wait_handles is not None:
for req in wait_handles:
req.wait()

View File

@@ -2,7 +2,6 @@ from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map