Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 10:34:41 +00:00)
[npu] add npu support for gemini and zero (#5067)
* [npu] setup device utils (#5047)
* [npu] add npu device support
* [npu] support low level zero
* [test] update npu zero plugin test
* [hotfix] fix import
* [test] recover tests
* [npu] gemini support npu (#5052)
* [npu] refactor device utils
* [gemini] support npu
* [example] llama2+gemini support npu
* [kernel] add arm cpu adam kernel (#5065)
* [kernel] add arm cpu adam
* [optim] update adam optimizer
* [kernel] arm cpu adam remove bf16 support
@@ -11,7 +11,7 @@ import torch.distributed as dist
 
 from colossalai.context import Config
 from colossalai.logging import get_dist_logger
-from colossalai.utils import set_device, set_seed
+from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
 
 
 def launch(
@@ -47,12 +47,15 @@ def launch(
     if rank == 0:
         warnings.warn("`config` is deprecated and will be removed soon.")
 
+    if IS_NPU_AVAILABLE and backend == "nccl":
+        backend = "hccl"
+
     # init default process group
     init_method = f"tcp://[{host}]:{port}"
     dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
 
     # set cuda device
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or IS_NPU_AVAILABLE:
         # if local rank is not given, calculate automatically
         set_device(local_rank)
 