mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[Device]Support npu (#6159)
* support npu * support pretrain support pretrain fix * support lora fix fix * support chatglm fix fxi fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix fix fix * Update train.py * Update train.py * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,8 @@ from torch.distributed import ProcessGroup
|
||||
from torch.distributed import distributed_c10d as c10d
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
from .stage_manager import PipelineStageManager
|
||||
|
||||
|
||||
@@ -31,7 +33,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
|
||||
buf = tensor.numpy().tobytes()[:tensor_size]
|
||||
if b"cuda" in buf:
|
||||
buf_array = bytearray(buf)
|
||||
device_index = torch.cuda.current_device()
|
||||
device_index = get_accelerator().current_device()
|
||||
# There might be more than one output tensors during forward
|
||||
for cuda_str in re.finditer(b"cuda", buf_array):
|
||||
pos = cuda_str.start()
|
||||
@@ -86,7 +88,7 @@ def _broadcast_object_list(
|
||||
else:
|
||||
current_device = torch.device("cpu")
|
||||
if is_nccl_backend:
|
||||
current_device = torch.device("cuda", torch.cuda.current_device())
|
||||
current_device = torch.device("cuda", get_accelerator().current_device())
|
||||
|
||||
my_rank = dist.get_rank()
|
||||
# Serialize object_list elements to tensors on src rank.
|
||||
@@ -139,14 +141,14 @@ def _broadcast_object_list(
|
||||
# unconsistence in device
|
||||
if (
|
||||
isinstance(unpickle_object, torch.Tensor)
|
||||
and unpickle_object.device.index != torch.cuda.current_device()
|
||||
and unpickle_object.device.index != get_accelerator().current_device()
|
||||
):
|
||||
unpickle_object = unpickle_object.cuda()
|
||||
unpickle_object = unpickle_object.to(get_accelerator().current_device())
|
||||
|
||||
object_list[i] = unpickle_object
|
||||
|
||||
|
||||
def _check_for_nccl_backend(group):
|
||||
def _check_for_nccl_hccl_backend(group):
|
||||
pg = group or c10d._get_default_group()
|
||||
# Gate PG wrapper check on Gloo availability.
|
||||
if c10d._GLOO_AVAILABLE:
|
||||
@@ -154,14 +156,14 @@ def _check_for_nccl_backend(group):
|
||||
while isinstance(pg, c10d._ProcessGroupWrapper):
|
||||
pg = pg.wrapped_pg
|
||||
|
||||
return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
|
||||
return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and pg.name() == c10d.Backend.NCCL
|
||||
|
||||
|
||||
def _check_device(group):
|
||||
is_nccl_backend = _check_for_nccl_backend(group)
|
||||
is_nccl_backend = _check_for_nccl_hccl_backend(group)
|
||||
current_device = torch.device("cpu")
|
||||
if is_nccl_backend:
|
||||
current_device = torch.device("cuda", torch.cuda.current_device())
|
||||
current_device = torch.device(get_accelerator().current_device())
|
||||
return current_device, is_nccl_backend
|
||||
|
||||
|
||||
@@ -348,8 +350,11 @@ def _send_recv_serialization_object(
|
||||
|
||||
unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())
|
||||
|
||||
if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
|
||||
unpickle_object = unpickle_object.cuda()
|
||||
if (
|
||||
isinstance(unpickle_object, torch.Tensor)
|
||||
and unpickle_object.device.index != get_accelerator().current_device()
|
||||
):
|
||||
unpickle_object = unpickle_object.to(get_accelerator().current_device())
|
||||
|
||||
return unpickle_object
|
||||
|
||||
@@ -474,9 +479,11 @@ def _p2p_comm(
|
||||
recv_prev_shape = None
|
||||
|
||||
if tensor_send_next is not None:
|
||||
send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
|
||||
send_next_shape = torch.tensor(
|
||||
tensor_send_next.size(), device=get_accelerator().current_device(), dtype=torch.int64
|
||||
)
|
||||
if recv_prev:
|
||||
recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
|
||||
recv_prev_shape = torch.empty((3), device=get_accelerator().current_device(), dtype=torch.int64)
|
||||
|
||||
ops = []
|
||||
if send_next_shape is not None:
|
||||
@@ -501,7 +508,7 @@ def _p2p_comm(
|
||||
# send and recv data
|
||||
tensor_recv_prev = None
|
||||
if recv_prev:
|
||||
tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)
|
||||
tensor_recv_prev = torch.empty(recv_prev_shape, device=get_accelerator().current_device(), dtype=comm_dtype)
|
||||
|
||||
ops = []
|
||||
if tensor_send_next is not None:
|
||||
|
@@ -2,7 +2,6 @@ from functools import partial
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.distributed
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.utils._pytree import tree_map
|
||||
@@ -18,7 +17,7 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_
|
||||
from .base import PipelineSchedule
|
||||
|
||||
|
||||
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
|
||||
def _wait_p2p(wait_handles) -> None:
|
||||
if wait_handles is not None:
|
||||
for req in wait_handles:
|
||||
req.wait()
|
||||
|
@@ -2,7 +2,6 @@ from functools import partial
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
from torch.nn import Module
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
Reference in New Issue
Block a user