[npu] change device to accelerator api (#5239)

* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
Author: Hongxin Liu
Date: 2024-01-09 10:20:05 +08:00
Committed by: GitHub
Parent: dd2c28a323
Commit: d202cc28c0
128 changed files with 1773 additions and 868 deletions
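For readers of the diff below, the change repeated across the touched files is the move from the `colossalai.utils.get_current_device` helper to the accelerator API. A minimal sketch of that pattern, using only the calls visible in the diff (the tensor shape is arbitrary illustration):

```python
# Device lookup before and after this commit.
import torch

from colossalai.accelerator import get_accelerator

# Old pattern (removed in this commit):
#   from colossalai.utils import get_current_device
#   x = torch.randn(16, 64, device=get_current_device())

# New pattern: ask the active accelerator backend (CUDA, NPU, or CPU)
# for the current device instead of assuming CUDA.
x = torch.randn(16, 64, device=get_accelerator().get_current_device())
```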


@@ -7,12 +7,12 @@ import torch
import torch.distributed as dist
import colossalai
+from colossalai.accelerator import get_accelerator
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
from tests.test_moe.moe_utils import MoeGradientHandler
@@ -23,8 +23,9 @@ def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_
tp_model (MoeModule)
local_model (MoeModule)
"""
-for (tp_name, tp_param), (local_name, local_param) in \
-zip(tp_model.named_parameters(), local_model.named_parameters()):
+for (tp_name, tp_param), (local_name, local_param) in zip(
+tp_model.named_parameters(), local_model.named_parameters()
+):
assert tp_name == local_name
if not is_moe_tensor(tp_param):
if assert_grad_flag:
@@ -54,8 +55,7 @@ def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag:
tp_model (MoeModule)
ep_model (MoeModule)
"""
-for (tp_name, tp_param), (ep_name, ep_param) in \
-zip(tp_model.named_parameters(), ep_model.named_parameters()):
+for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()):
assert tp_name == ep_name
if not is_moe_tensor(tp_param):
if assert_grad_flag:
@@ -97,8 +97,9 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
local_model (MoeModule)
ep_model (MoeModule)
"""
-for (local_name, local_param), (ep_name, ep_param) in \
-zip(local_model.named_parameters(), ep_model.named_parameters()):
+for (local_name, local_param), (ep_name, ep_param) in zip(
+local_model.named_parameters(), ep_model.named_parameters()
+):
assert local_name == ep_name
if "experts" not in local_name:
if assert_grad_flag:
@@ -141,14 +142,14 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
num_experts=num_experts,
hidden_size=dim,
intermediate_size=dim * 2,
-enable_hierarchical_comm=enable_hierarchical_comm
+enable_hierarchical_comm=enable_hierarchical_comm,
)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="TP")
tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
-ep_model = ep_model.to(get_current_device())
-tp_model = tp_model.to(get_current_device())
-local_model = local_model.to(get_current_device())
+ep_model = ep_model.to(get_accelerator().get_current_device())
+tp_model = tp_model.to(get_accelerator().get_current_device())
+local_model = local_model.to(get_accelerator().get_current_device())
# sync ep param
sync_moe_model_param(ep_model)
@@ -163,11 +164,11 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
tp_grad_handler = MoeGradientHandler(tp_model)
rank = dist.get_rank()
-input_data = torch.randn(batch_size, dim, device=get_current_device())
+input_data = torch.randn(batch_size, dim, device=get_accelerator().get_current_device())
micro_batch_size = batch_size // world_size
index = rank * micro_batch_size
# NOTE: ep & tp takes in sharded data for each process
-shard_data = input_data.detach()[index:index + micro_batch_size]
+shard_data = input_data.detach()[index : index + micro_batch_size]
out_local = local_model(input_data)
MOE_MANAGER.reset_loss()
@@ -176,13 +177,15 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
out_ep = ep_model(shard_data)
MOE_MANAGER.reset_loss()
-assert torch.allclose(out_tp, out_ep, atol=1e-6), \
-f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}"
+assert torch.allclose(
+out_tp, out_ep, atol=1e-6
+), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}"
try:
-out_local_slice = out_local[index:index + micro_batch_size]
-assert torch.allclose(out_ep, out_local_slice, atol=1e-6), \
-f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}"
-except AssertionError as e:
+out_local_slice = out_local[index : index + micro_batch_size]
+assert torch.allclose(
+out_ep, out_local_slice, atol=1e-6
+), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}"
+except AssertionError:
"""
e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1
router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2
@@ -193,8 +196,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature.
"""
warnings.warn(
-"EP & TP may result in different behavior from local model. "
-"Please check the comments for details."
+"EP & TP may result in different behavior from local model. " "Please check the comments for details."
)
out_local.mean().backward()
@@ -208,10 +210,9 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
try:
sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
-except AssertionError as e:
+except AssertionError:
warnings.warn(
-"EP & TP may result in different behavior from local model. "
-"Please check the comments for details."
+"EP & TP may result in different behavior from local model. " "Please check the comments for details."
)
@@ -219,14 +220,17 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
@pytest.mark.parametrize("num_experts", [4, 64])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("dim", [64])
-@pytest.mark.parametrize("config", [
-{"enable_hierarchical_comm": False},
-{"enable_hierarchical_comm": True},
-])
+@pytest.mark.parametrize(
+"config",
+[
+{"enable_hierarchical_comm": False},
+{"enable_hierarchical_comm": True},
+],
+)
@rerun_if_address_is_in_use()
def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)
-if __name__ == '__main__':
+if __name__ == "__main__":
test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)
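The "fix autocast" and "use nullcontext" entries in the commit message describe a common way to keep mixed-precision code paths uniform across backends. The helper below is an illustrative sketch of that pattern only, not the code changed in this commit; `autocast_or_null` and `forward_step` are hypothetical names.

```python
# Illustrative only: fall back to contextlib.nullcontext when autocast is
# not requested or not available, so call sites keep a single `with` block.
from contextlib import nullcontext

import torch


def autocast_or_null(enabled: bool):
    # Hypothetical helper: torch.autocast on CUDA when AMP is requested,
    # otherwise a no-op context manager.
    if enabled and torch.cuda.is_available():
        return torch.autocast(device_type="cuda")
    return nullcontext()


def forward_step(model, batch, use_amp: bool = False):
    with autocast_or_null(use_amp):
        return model(batch)
```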