[inference] support only TP (#4998)

* support only tp

* enable tp
Xu Kai, 2023-11-01 16:33:30 +08:00 (committed by FoolPlayer)
parent f71e63b0f3
commit f747d13040
4 changed files with 120 additions and 41 deletions
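In short, the change adds a tensor-parallel-only test path (tp_size=2, pp_size=1) next to the existing pipeline-only and hybrid TP+PP inference tests. Below is a minimal sketch of how that path is driven end to end; it reuses only identifiers visible in the hunks, assumes the test helpers (spawn, parameterize, and friends) come from colossalai.testing as elsewhere in the test suite, and stubs out pipeline_inference_test, which is defined earlier in the test file and not shown in this diff.

import colossalai
import torch
from colossalai.testing import spawn  # assumption: spawn is the launcher helper used by this test file


def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    # Stand-in for the helper defined earlier in the test file (not shown in the
    # hunks); the real helper presumably shards the model according to
    # tp_size/pp_size and runs a short generation.
    print(f"would run inference with tp={tp_size}, pp={pp_size}")


def _tp_only_worker(rank, world_size, port):
    # tp_size=2, pp_size=1: pure tensor parallelism, no pipeline stages.
    colossalai.launch(config={}, rank=rank, world_size=world_size,
                      host="localhost", port=port, backend="nccl")
    pipeline_inference_test(tp_size=2, pp_size=1, max_output_len=2, micro_batch_size=1)
    torch.cuda.empty_cache()


if __name__ == "__main__":
    spawn(_tp_only_worker, nprocs=2)  # one process per GPU; world_size == tp_size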


@@ -65,6 +65,16 @@ def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch
     torch.cuda.empty_cache()
 
 
+@parameterize("tp_size", [2])
+@parameterize("pp_size", [1])
+@parameterize("max_output_len", [2])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+    torch.cuda.empty_cache()
+
+
 def check_pipeline_inference(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     run_pipeline_inference_test()
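A note on the stacked @parameterize decorators in the hunk above: as used throughout ColossalAI's test suite, each one (my understanding, not a quote of its implementation) re-invokes the wrapped function once per listed value, so calling run_tp_inference_test() with no arguments expands to the cross product of the lists, here a single combination. A self-contained sketch of that assumed expansion behaviour:

from itertools import product


def expand(**param_lists):
    """Yield one kwargs dict per combination, mimicking stacked @parameterize decorators."""
    keys = list(param_lists)
    for values in product(*(param_lists[k] for k in keys)):
        yield dict(zip(keys, values))


# The single-element lists from the decorators pin the test to one configuration.
for kwargs in expand(tp_size=[2], pp_size=[1], max_output_len=[2], micro_batch_size=[1]):
    print(kwargs)  # -> {'tp_size': 2, 'pp_size': 1, 'max_output_len': 2, 'micro_batch_size': 1}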
@@ -75,6 +85,11 @@ def check_tp_pipeline_inference(rank, world_size, port):
     run_tp_pipeline_inference_test()
 
 
+def check_tp_inference(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_tp_inference_test()
+
+
 @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
@@ -82,6 +97,7 @@ def check_tp_pipeline_inference(rank, world_size, port):
 def test_pipeline_inference():
     spawn(check_pipeline_inference, nprocs=2)
     spawn(check_tp_pipeline_inference, nprocs=4)
+    spawn(check_tp_inference, nprocs=2)
 
 
 if __name__ == "__main__":