Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-15 22:19:38 +00:00)
[npu] change device to accelerator api (#5239)
* update accelerator
* fix timer
* fix amp
* update
* fix
* update bug
* add error raise
* fix autocast
* fix set device
* remove doc accelerator
* update doc
* update doc
* update doc
* use nullcontext
* update cpu
* update null context
* change time limit for example
* update
* update
* update
* update
* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
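The substance of the diff below is the switch from the standalone device helper in `colossalai.utils` to the accelerator API, so the same test code resolves its device through whichever backend is active (e.g. CUDA or NPU). A minimal sketch of the pattern, using only the two calls the diff itself introduces (`get_accelerator()` and `get_current_device()`); the module and tensor here are illustrative, not part of the test:

```python
import torch

from colossalai.accelerator import get_accelerator

# Old pattern, removed by this diff:
#     from colossalai.utils import get_current_device
#     model = model.to(get_current_device())

# New pattern: resolve the current device through the accelerator API,
# which dispatches to the active backend (e.g. CUDA or NPU).
device = get_accelerator().get_current_device()

model = torch.nn.Linear(8, 8).to(device)   # illustrative module, not from the test
x = torch.randn(4, 8, device=device)       # inputs are created on the same device
y = model(x)
```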
@@ -7,12 +7,12 @@ import torch
 import torch.distributed as dist
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.moe import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import sync_moe_model_param
 from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 from tests.test_moe.moe_utils import MoeGradientHandler
 
 
@@ -23,8 +23,9 @@ def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_
         tp_model (MoeModule)
         local_model (MoeModule)
     """
-    for (tp_name, tp_param), (local_name, local_param) in \
-            zip(tp_model.named_parameters(), local_model.named_parameters()):
+    for (tp_name, tp_param), (local_name, local_param) in zip(
+        tp_model.named_parameters(), local_model.named_parameters()
+    ):
         assert tp_name == local_name
         if not is_moe_tensor(tp_param):
             if assert_grad_flag:
@@ -54,8 +55,7 @@ def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag:
         tp_model (MoeModule)
         ep_model (MoeModule)
     """
-    for (tp_name, tp_param), (ep_name, ep_param) in \
-            zip(tp_model.named_parameters(), ep_model.named_parameters()):
+    for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()):
         assert tp_name == ep_name
         if not is_moe_tensor(tp_param):
             if assert_grad_flag:
@@ -97,8 +97,9 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
         local_model (MoeModule)
         ep_model (MoeModule)
     """
-    for (local_name, local_param), (ep_name, ep_param) in \
-            zip(local_model.named_parameters(), ep_model.named_parameters()):
+    for (local_name, local_param), (ep_name, ep_param) in zip(
+        local_model.named_parameters(), ep_model.named_parameters()
+    ):
         assert local_name == ep_name
         if "experts" not in local_name:
             if assert_grad_flag:
@@ -141,14 +142,14 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
         num_experts=num_experts,
         hidden_size=dim,
         intermediate_size=dim * 2,
-        enable_hierarchical_comm=enable_hierarchical_comm
+        enable_hierarchical_comm=enable_hierarchical_comm,
     )
     MOE_MANAGER.__init__()
     MOE_MANAGER.setup(parallel="TP")
     tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
-    ep_model = ep_model.to(get_current_device())
-    tp_model = tp_model.to(get_current_device())
-    local_model = local_model.to(get_current_device())
+    ep_model = ep_model.to(get_accelerator().get_current_device())
+    tp_model = tp_model.to(get_accelerator().get_current_device())
+    local_model = local_model.to(get_accelerator().get_current_device())
 
     # sync ep param
     sync_moe_model_param(ep_model)
@@ -163,11 +164,11 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     tp_grad_handler = MoeGradientHandler(tp_model)
 
     rank = dist.get_rank()
-    input_data = torch.randn(batch_size, dim, device=get_current_device())
+    input_data = torch.randn(batch_size, dim, device=get_accelerator().get_current_device())
     micro_batch_size = batch_size // world_size
     index = rank * micro_batch_size
     # NOTE: ep & tp takes in sharded data for each process
-    shard_data = input_data.detach()[index:index + micro_batch_size]
+    shard_data = input_data.detach()[index : index + micro_batch_size]
 
     out_local = local_model(input_data)
     MOE_MANAGER.reset_loss()
@@ -176,13 +177,15 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     out_ep = ep_model(shard_data)
     MOE_MANAGER.reset_loss()
 
-    assert torch.allclose(out_tp, out_ep, atol=1e-6), \
-        f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}"
+    assert torch.allclose(
+        out_tp, out_ep, atol=1e-6
+    ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}"
     try:
-        out_local_slice = out_local[index:index + micro_batch_size]
-        assert torch.allclose(out_ep, out_local_slice, atol=1e-6), \
-            f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}"
-    except AssertionError as e:
+        out_local_slice = out_local[index : index + micro_batch_size]
+        assert torch.allclose(
+            out_ep, out_local_slice, atol=1e-6
+        ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}"
+    except AssertionError:
         """
         e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1
         router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2
@@ -193,8 +196,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
         The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature.
         """
         warnings.warn(
-            "EP & TP may result in different behavior from local model. "
-            "Please check the comments for details."
+            "EP & TP may result in different behavior from local model. " "Please check the comments for details."
         )
 
     out_local.mean().backward()
@@ -208,10 +210,9 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
     try:
         sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
-    except AssertionError as e:
+    except AssertionError:
         warnings.warn(
-            "EP & TP may result in different behavior from local model. "
-            "Please check the comments for details."
+            "EP & TP may result in different behavior from local model. " "Please check the comments for details."
        )
 
 
@@ -219,14 +220,17 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
 @pytest.mark.parametrize("num_experts", [4, 64])
 @pytest.mark.parametrize("batch_size", [16])
 @pytest.mark.parametrize("dim", [64])
-@pytest.mark.parametrize("config", [
-    {"enable_hierarchical_comm": False},
-    {"enable_hierarchical_comm": True},
-])
+@pytest.mark.parametrize(
+    "config",
+    [
+        {"enable_hierarchical_comm": False},
+        {"enable_hierarchical_comm": True},
+    ],
+)
 @rerun_if_address_is_in_use()
 def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
     spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)