[npu] add npu support for gemini and zero (#5067)

* [npu] setup device utils (#5047)

* [npu] add npu device support

* [npu] support low level zero

* [test] update npu zero plugin test

* [hotfix] fix import

* [test] recover tests

* [npu] gemini support npu (#5052)

* [npu] refactor device utils

* [gemini] support npu

* [example] llama2+gemini support npu

* [kernel] add arm cpu adam kernel (#5065)

* [kernel] add arm cpu adam

* [optim] update adam optimizer

* [kernel] arm cpu adam remove bf16 support
Author: Hongxin Liu
Date: 2023-11-20 16:12:41 +08:00
Committed by: GitHub
Parent: 8d56c9c389
Commit: e5ce4c8ea6
46 changed files with 994 additions and 233 deletions
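The diffs below all follow one pattern: device-specific calls (torch.cuda.*, .cuda(), "cuda" strings) are routed through colossalai.utils.device so the same tests run on CUDA and Ascend NPU. A minimal sketch of that dispatch pattern follows; the torch_npu import and the IS_NPU_AVAILABLE flag are illustrative assumptions, not necessarily the module's real internals.

import torch

try:
    # Assumption: Ascend's torch_npu backend patches torch with a torch.npu
    # namespace when installed; the import fails on CUDA-only hosts.
    import torch_npu  # noqa: F401
    IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
    IS_NPU_AVAILABLE = False

def get_current_device() -> torch.device:
    # Prefer NPU, then CUDA, then CPU, mirroring the calls in the diffs below.
    if IS_NPU_AVAILABLE:
        return torch.device(f"npu:{torch.npu.current_device()}")
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return torch.device("cpu")

def empty_cache() -> None:
    # Backend-agnostic replacement for torch.cuda.empty_cache().
    if IS_NPU_AVAILABLE:
        torch.npu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()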

[changed file: filename not shown in source]

@@ -2,11 +2,14 @@ from typing import Optional
 
 import torch
 import torch.distributed as dist
+from torch.optim import Adam
 
 import colossalai
+import colossalai.utils.device as device_utils
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
-from colossalai.nn.optimizer import HybridAdam
+
+# from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -19,16 +22,17 @@ _STUCK_MODELS = ["transformers_albert_for_multiple_choice"]
 
 
 def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
+    device = device_utils.get_current_device()
     try:
         plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
         booster = Booster(plugin=plugin)
         model = model_fn()
-        optimizer = HybridAdam(model.parameters(), lr=1e-3)
+        optimizer = Adam(model.parameters(), lr=1e-3)
         criterion = lambda x: x.mean()
         data = data_gen_fn()
 
         data = {
-            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
+            k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
         }
 
         model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
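For reference, the batch-moving idiom in this hunk generalizes as below; move_batch is a hypothetical helper, and the class-name check catches tensor-like wrappers that torch.is_tensor may miss.

import torch

def move_batch(data: dict, device: torch.device) -> dict:
    # Tensors (and tensor-like objects whose class name contains "Tensor")
    # are moved to `device`; ints, strings, etc. pass through unchanged.
    return {
        k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v
        for k, v in data.items()
    }

batch = {"input_ids": torch.randint(0, 100, (2, 8)), "num_choices": 4}
batch = move_batch(batch, torch.device("cpu"))  # "cpu" stands in for get_current_device()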
@@ -65,7 +69,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
             continue
 
         err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
-        torch.cuda.empty_cache()
+        device_utils.empty_cache()
 
         if err is None:
             passed_models.append(name)
@@ -89,7 +93,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
 
 @rerun_if_address_is_in_use()
 def test_low_level_zero_plugin(early_stop: bool = True):
-    spawn(run_dist, 4, early_stop=early_stop)
+    spawn(run_dist, 2, early_stop=early_stop)
 
 
 if __name__ == "__main__":
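The world-size reduction above is easier to read knowing the spawn convention these tests share: spawn(fn, N, **kwargs) launches N worker processes and invokes fn(rank, world_size, port, **kwargs) in each. A sketch, with the test body elided; the rationale that two ranks fit NPU CI hosts is inferred, not stated in the commit.

from colossalai.testing import rerun_if_address_is_in_use, spawn

def run_dist(rank, world_size, port, early_stop: bool = True):
    ...  # launch the distributed env and run the plugin checks

@rerun_if_address_is_in_use()
def test_low_level_zero_plugin(early_stop: bool = True):
    spawn(run_dist, 2, early_stop=early_stop)  # now 2 ranks instead of 4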

[changed file: filename not shown in source]

@@ -3,7 +3,7 @@ import pytest
 import colossalai
 from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
 from colossalai.testing import spawn
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 
 
 def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():

[changed file: filename not shown in source]

@@ -9,7 +9,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd_bwd

[changed file: filename not shown in source]

@@ -9,7 +9,7 @@ import colossalai
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd

[changed file: filename not shown in source]

@@ -11,7 +11,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd

[changed file: filename not shown in source]

@@ -9,7 +9,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd_bwd

[changed file: filename not shown in source]

@@ -9,7 +9,7 @@ from torch.testing import assert_close
 
 import colossalai
 from colossalai.testing import spawn
 from colossalai.testing.random import seed_all
-from colossalai.utils import conditional_context
+from colossalai.utils import conditional_context, get_current_device
 from colossalai.zero import LowLevelZeroOptimizer
 
@@ -28,9 +28,9 @@ class MlpModel(nn.Module):
 def exam_zero_1_2_grad_acc():
     local_rank = torch.distributed.get_rank()
     seed_all(2009)
+    device = get_current_device()
     # create model
-    zero1_model = MlpModel().cuda()
+    zero1_model = MlpModel().to(device)
     zero2_model = copy.deepcopy(zero1_model)
     # create optimizer
     zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
@@ -43,8 +43,8 @@ def exam_zero_1_2_grad_acc():
     )
 
     # create data
     seed_all(2021 + local_rank)
-    input_data1 = torch.randn(32, 128).cuda()
-    input_data2 = torch.randn(32, 128).cuda()
+    input_data1 = torch.randn(32, 128, device=device)
+    input_data2 = torch.randn(32, 128, device=device)
 
     def fwd_bwd_func(number, cur_data, check_flag):
         # zero-dp forward
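Passing device= at construction is not only device-neutral; it also allocates the tensor directly on the target backend instead of building it on CPU first. A small illustration, with "cpu" standing in for the detected device:

import torch

device = torch.device("cpu")  # stand-in for get_current_device()
x = torch.randn(32, 128, device=device)  # allocated directly on `device`
y = torch.randn(32, 128).to(device)      # built on CPU, then moved if `device` differs
assert x.shape == y.shape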
@@ -71,14 +71,15 @@ def exam_zero_1_2_grad_acc():
 def exam_zero_1_grad_acc(sync):
     local_rank = torch.distributed.get_rank()
     seed_all(2008)
+    device = get_current_device()
 
     # create models
     zero_model = MlpModel()
     torch_model = copy.deepcopy(zero_model)
 
     seed_all(2008)
-    zero_model = zero_model.cuda()
-    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
+    zero_model = zero_model.to(device)
+    torch_model = DDP(torch_model.to(device), bucket_cap_mb=0)
 
     # create optimizer
     zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
@@ -94,8 +95,8 @@ def exam_zero_1_grad_acc(sync):
 
     # create data
     seed_all(2022 + local_rank)
-    input_data1 = torch.randn(32, 128).cuda()
-    input_data2 = torch.randn(32, 128).cuda()
+    input_data1 = torch.randn(32, 128, device=device)
+    input_data2 = torch.randn(32, 128, device=device)
 
     def fwd_bwd_func(no_sync, cur_data, check_flag):
         # zero1 fwd and bwd
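For context, exam_zero_1_grad_acc compares ZeRO-1 gradient accumulation against a DDP baseline in which only the final micro-batch synchronizes gradients. A minimal sketch of that baseline pattern; names and the loop are illustrative, not the test's actual code, and it requires an initialized process group to run.

from contextlib import nullcontext

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def ddp_accumulate(model: DDP, micro_batches, optimizer):
    for i, batch in enumerate(micro_batches):
        is_last = i == len(micro_batches) - 1
        # no_sync() skips DDP's gradient all-reduce on non-final micro-batches,
        # so gradients accumulate locally and sync once on the last backward.
        ctx = nullcontext() if is_last else model.no_sync()
        with ctx:
            model(batch).mean().backward()
    optimizer.step()
    optimizer.zero_grad()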