[Fix] Fix Inference Example, Tests, and Requirements (#5688)

* clean requirements

* modify example inference struct

* add test ci scripts

* mark test_infer as submodule

* rm deprecated cls & deps

* import of HAS_FLASH_ATTN

* prune inference tests to be run

* prune triton kernel tests

* increment pytest timeout mins

* revert import path in openmoe
This commit is contained in:
Yuanheng Zhao
2024-05-08 11:30:15 +08:00
committed by GitHub
parent f9afe0addd
commit 55cc7f3df7
23 changed files with 46 additions and 328 deletions

View File

@@ -165,8 +165,10 @@ def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
func_to_run(**kwargs)
@pytest.mark.largedist
@parameterize("prompt_template", [None, "llama"])
@parameterize("do_sample", [False])
@rerun_if_address_is_in_use()
def test_tp_engine(prompt_template, do_sample):
kwargs1 = {
"use_engine": True,
@@ -186,18 +188,14 @@ def test_tp_engine(prompt_template, do_sample):
assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
@pytest.mark.largedist
@parameterize("num_layers", [1])
@parameterize("max_length", [64])
@rerun_if_address_is_in_use()
def test_spec_dec(num_layers, max_length):
spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_inference_engine():
if __name__ == "__main__":
test_tp_engine()
test_spec_dec()
if __name__ == "__main__":
test_inference_engine()