[moe] init moe plugin comm setting with sp
@@ -23,7 +23,7 @@ NUM_HEADS = 4
 TOP_K = 1
 
 
-@parameterize("config", [(1, 1, 1)])
+@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
     stage, ep_size, tp_size = config
     dtype = torch.float16
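Note (illustrative, not part of the diff): the (stage, ep_size, tp_size) tuples above are fanned out by the test's parameterize decorator, one invocation per tuple. A minimal sketch of such a decorator, assuming it simply loops over the supplied values and injects each as a keyword argument (the real helper comes from colossalai.testing):

from functools import wraps

def parameterize_sketch(arg_name, values):
    # Illustrative stand-in for a parameterize-style decorator: run the wrapped
    # test once per value, passing it under `arg_name`.
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator

@parameterize_sketch("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
def run_demo(config):
    stage, ep_size, tp_size = config  # zero stage, expert-parallel size, tensor-parallel size
    print(f"stage={stage}, ep_size={ep_size}, tp_size={tp_size}")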
@@ -24,11 +24,10 @@ NUM_HEADS = 4
 TOP_K = 1
 
 
-@parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
+@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
     stage, ep_size, tp_size = config
-    dtype = torch.float32
-
+    dtype, precision = torch.float16, "fp16"
     rank = torch.distributed.get_rank()
     torch.cuda.set_device(dist.get_rank())
 
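Note (illustrative): the test now keeps the working dtype and the plugin's precision string in lockstep (torch.float16 with "fp16"). A hedged sketch of that pairing; only the fp16 entry is taken from the diff, the other mappings are assumptions:

import torch

_PRECISION_BY_DTYPE = {
    torch.float16: "fp16",
    torch.bfloat16: "bf16",
    torch.float32: "fp32",
}

def pick_dtype_and_precision(dtype):
    # Return the dtype together with the matching precision string.
    return dtype, _PRECISION_BY_DTYPE[dtype]

dtype, precision = pick_dtype_and_precision(torch.float16)  # (torch.float16, "fp16")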
@@ -40,7 +39,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         zero_stage=stage,
         overlap_communication=False,
         initial_scale=1,
-        precision="fp32",
+        precision=precision,
     )
     booster = Booster(plugin=plugin)
 
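Note (illustrative): the precision string is threaded into the plugin so the boosted model runs in the same precision as the reference model. A hedged sketch of that setup, assuming the keyword arguments shown in the hunk; the plugin import path and pp_size=1 are assumptions and may differ by version:

from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin  # path assumed

def build_booster(stage, ep_size, tp_size, precision="fp16"):
    # Mirror the kwargs from the hunk above; pp_size=1 is an assumption here.
    plugin = MoeHybridParallelPlugin(
        pp_size=1,
        tp_size=tp_size,
        ep_size=ep_size,
        zero_stage=stage,
        overlap_communication=False,
        initial_scale=1,
        precision=precision,
    )
    return Booster(plugin=plugin)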
@@ -109,7 +108,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     dist.barrier()
 
-    saved_model = MixtralModel.from_pretrained(model_dir).cuda()
+    saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
     check_model_equal(torch_model, saved_model)
 
     dist.barrier()
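Note (illustrative): the reloaded checkpoint is now cast to the working dtype before comparison, since from_pretrained may hand the weights back in a different dtype (typically fp32) than the fp16 model being checked. A tiny sketch of the mismatch and the fix:

import torch

p_fp16 = torch.randn(4, dtype=torch.float16)   # parameter from the fp16 run
p_reloaded = p_fp16.to(torch.float32)          # stand-in for a reloaded fp32 copy
assert p_reloaded.dtype != p_fp16.dtype        # dtype mismatch before the cast
assert torch.equal(p_reloaded.to(torch.float16), p_fp16)  # cast makes them comparable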
@@ -26,9 +26,7 @@ top_k = 2
 def check_model_equal(model1, model2):
     assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
     for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
-        if loose_close(p1, p2, p1.dtype):
-            print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
-            raise AssertionError(f"Model parameter {name} is not equal")
+        loose_close(p1, p2, p1.dtype)
 
 
 def get_optimizer_snapshot(optim):
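Note (illustrative): check_model_equal now delegates the comparison to loose_close, which is assumed to assert closeness with dtype-dependent tolerances rather than bitwise equality. A hedged sketch of such a helper; the tolerance values are made up for illustration:

import torch

_RTOL = {torch.float16: 1e-2, torch.bfloat16: 4e-2, torch.float32: 1e-5}
_ATOL = {torch.float16: 3e-3, torch.bfloat16: 4e-3, torch.float32: 1e-7}

def loose_close_sketch(a, b, dtype):
    # Compare two tensors in the given dtype with per-dtype tolerances.
    a, b = a.detach().to(dtype), b.detach().to(dtype)
    assert torch.allclose(a, b, rtol=_RTOL[dtype], atol=_ATOL[dtype]), \
        f"tensors differ beyond tolerance for {dtype}"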
@@ -141,12 +141,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         # {
         #     "tp_size": 1,
-        #     "pp_size": 2,
+        #     "pp_size": 1,
         #     "num_microbatches": 2,
         #     "ep_size": 2,
-        #     "zero_stage": 1,
+        #     "zero_stage": 0,
         #     "overlap_communication": False,
-        #     "precision": "fp32",
+        #     "precision": "fp16",
         # }, # [dp(4)] + [moe_dp(4)]
         # {
         #     "tp_size": 1,
@@ -169,7 +169,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         { # Ulysess + Flash attention
             "tp_size": 1,
             "pp_size": 1,
-            "sp_size": 4,
+            "sp_size": 2,
             "ep_size": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",
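Note (illustrative): with sp_size reduced from 4 to 2 in this Ulysses (all_to_all) case, the 4-GPU test world is assumed to split into a sequence-parallel group of 2 and a data-parallel group of 2, so the MoE communication setup is exercised alongside sequence parallelism. A quick arithmetic check under that assumption:

# Assumed relation for the all_to_all sequence-parallel mode:
# dp_size = world_size // (tp_size * pp_size * sp_size)
world_size = 4
cfg = {"tp_size": 1, "pp_size": 1, "sp_size": 2, "ep_size": 1}
dp_size = world_size // (cfg["tp_size"] * cfg["pp_size"] * cfg["sp_size"])
assert dp_size == 2  # 1 * 1 * 2 * 2 == 4 ranks in total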