[moe] init moe plugin comm setting with sp

Author: hxwang
Date: 2024-07-18 08:37:06 +00:00
Committed by: Hongxin Liu
Parent: 09d6280d3e
Commit: 877d94bb8c

7 changed files with 101 additions and 95 deletions

View File

@@ -23,7 +23,7 @@ NUM_HEADS = 4
 TOP_K = 1
 
 
-@parameterize("config", [(1, 1, 1)])
+@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
     stage, ep_size, tp_size = config
     dtype = torch.float16
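
Each tuple in the widened parameterization is unpacked as (stage, ep_size, tp_size) inside run_zero_with_original_model, so the test now covers ZeRO stage 0 as well as stage 1 under several expert-parallel and tensor-parallel sizes. A minimal, self-contained sketch of that mapping (the describe helper is illustrative only, not part of the test):

    from typing import Tuple

    CONFIGS = [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)]

    def describe(config: Tuple[int, ...]) -> str:
        # mirrors the unpacking in run_zero_with_original_model above
        stage, ep_size, tp_size = config
        return f"zero_stage={stage}, ep_size={ep_size}, tp_size={tp_size}"

    for cfg in CONFIGS:
        print(describe(cfg))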

View File

@@ -24,11 +24,10 @@ NUM_HEADS = 4
 TOP_K = 1
 
 
-@parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
+@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
     stage, ep_size, tp_size = config
-    dtype = torch.float32
+    dtype, precision = torch.float16, "fp16"
     rank = torch.distributed.get_rank()
     torch.cuda.set_device(dist.get_rank())
@@ -40,7 +39,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         zero_stage=stage,
         overlap_communication=False,
         initial_scale=1,
-        precision="fp32",
+        precision=precision,
     )
     booster = Booster(plugin=plugin)
@@ -109,7 +108,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     dist.barrier()
-    saved_model = MixtralModel.from_pretrained(model_dir).cuda()
+    saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
     check_model_equal(torch_model, saved_model)
     dist.barrier()
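
The fp16 switch has to be applied consistently: the plugin is built with precision="fp16", the test tensors use torch.float16, and the checkpoint reloaded via from_pretrained is cast with .to(dtype) before check_model_equal runs. A hedged sketch of how these pieces fit together, assuming the plugin class is MoeHybridParallelPlugin with the keyword arguments visible in the hunk (the import path and pp_size value are assumptions, not taken from this diff):

    import torch
    from colossalai.booster import Booster
    from colossalai.booster.plugin import MoeHybridParallelPlugin  # assumed import path

    def build_booster(stage: int, ep_size: int, tp_size: int):
        # keep the plugin precision and the comparison dtype in sync,
        # otherwise the reloaded fp16 checkpoint cannot be compared against the live model
        dtype, precision = torch.float16, "fp16"
        plugin = MoeHybridParallelPlugin(
            tp_size=tp_size,
            pp_size=1,  # assumption: no pipeline parallelism in this test
            ep_size=ep_size,
            zero_stage=stage,
            overlap_communication=False,
            initial_scale=1,
            precision=precision,
        )
        return Booster(plugin=plugin), dtype

    # after saving, the checkpoint is reloaded in the same dtype before comparison:
    #     saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)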

View File

@@ -26,9 +26,7 @@ top_k = 2
 def check_model_equal(model1, model2):
     assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
     for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
-        if loose_close(p1, p2, p1.dtype):
-            print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
-            raise AssertionError(f"Model parameter {name} is not equal")
+        loose_close(p1, p2, p1.dtype)
 
 
 def get_optimizer_snapshot(optim):
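
check_model_equal now delegates the comparison to loose_close, which is expected to assert internally rather than return a flag. loose_close lives in the shared test utilities and is not shown in this diff; the helper below is a hypothetical stand-in that illustrates the dtype-aware tolerances such a check typically applies:

    import torch

    def loose_close_sketch(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype, name: str = "") -> None:
        # hypothetical tolerances: wider for reduced-precision dtypes
        if dtype in (torch.float16, torch.bfloat16):
            rtol, atol = 1e-1, 1e-1
        else:
            rtol, atol = 1e-5, 1e-8
        a, b = a.detach().float(), b.detach().float()
        assert torch.allclose(a, b, rtol=rtol, atol=atol), (
            f"Tensor {name} mismatch: max abs diff {(a - b).abs().max().item()}"
        )

    # usage mirroring the new check_model_equal loop
    m1, m2 = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
    m2.load_state_dict(m1.state_dict())
    for (name, p1), p2 in zip(m1.named_parameters(), m2.parameters()):
        loose_close_sketch(p1, p2, p1.dtype, name)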

View File

@@ -141,12 +141,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         # {
         #     "tp_size": 1,
-        #     "pp_size": 2,
+        #     "pp_size": 1,
        #     "num_microbatches": 2,
         #     "ep_size": 2,
-        #     "zero_stage": 1,
+        #     "zero_stage": 0,
         #     "overlap_communication": False,
-        #     "precision": "fp32",
+        #     "precision": "fp16",
         # },  # [dp(4)] + [moe_dp(4)]
         # {
         #     "tp_size": 1,
@@ -169,7 +169,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {  # Ulysess + Flash attention
             "tp_size": 1,
             "pp_size": 1,
-            "sp_size": 4,
+            "sp_size": 2,
             "ep_size": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",