[Fix] Fix Inference Example, Tests, and Requirements (#5688)

* clean requirements

* modify example inference struct

* add test ci scripts

* mark test_infer as submodule

* rm deprecated cls & deps

* fix import of HAS_FLASH_ATTN

* prune inference tests to be run

* prune triton kernel tests

* increment pytest timeout mins

* revert import path in openmoe
Author: Yuanheng Zhao
Date: 2024-05-08 11:30:15 +08:00
Committed by: GitHub
Parent: f9afe0addd
Commit: 55cc7f3df7
23 changed files with 46 additions and 328 deletions


@@ -11,13 +11,16 @@ MAX_LEN = 100
 SPEC_NUM = 5
+@pytest.fixture(scope="module")
+def tokenizer():
+    return AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
 @pytest.mark.parametrize("spec_num", [SPEC_NUM])
-def test_drafter(spec_num: int):
+def test_drafter(tokenizer, spec_num: int):
     torch.manual_seed(123)
     device = get_current_device()
-    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
     toy_config.pad_token_id = tokenizer.eos_token_id
     drafter_model = LlamaForCausalLM(toy_config)
@@ -39,10 +42,9 @@ def test_drafter(spec_num: int):
     assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
-def test_spec_dec():
+def test_spec_dec(tokenizer):
     spec_num = SPEC_NUM
     device = get_current_device()
-    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     tokenizer.pad_token = tokenizer.eos_token
     # Dummy config for Glide Model
@@ -67,5 +69,6 @@ def test_spec_dec():
 if __name__ == "__main__":
-    test_drafter(spec_num=SPEC_NUM)
-    test_spec_dec()
+    dummy_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    test_drafter(dummy_tokenizer, spec_num=SPEC_NUM)
+    test_spec_dec(dummy_tokenizer)
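
For reference, below is a minimal, self-contained sketch of the fixture pattern this diff introduces, assuming only pytest and transformers are installed. The ColossalAI-specific pieces (get_current_device, the drafter and Glide models, and the real assertions) are omitted, and the trivial assert bodies are placeholders rather than the project's actual checks.

import pytest
from transformers import AutoTokenizer

SPEC_NUM = 5


@pytest.fixture(scope="module")
def tokenizer():
    # Built once per test module instead of once inside every test function.
    return AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")


@pytest.mark.parametrize("spec_num", [SPEC_NUM])
def test_drafter(tokenizer, spec_num: int):
    # pytest injects the module-scoped fixture by matching the parameter name.
    assert tokenizer.eos_token_id is not None
    assert spec_num == SPEC_NUM


def test_spec_dec(tokenizer):
    tokenizer.pad_token = tokenizer.eos_token
    assert tokenizer.pad_token == tokenizer.eos_token


if __name__ == "__main__":
    # Outside of pytest the fixture is not resolved automatically, so the
    # __main__ block constructs a tokenizer by hand and passes it explicitly,
    # mirroring what the updated test file does.
    dummy_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    test_drafter(dummy_tokenizer, spec_num=SPEC_NUM)
    test_spec_dec(dummy_tokenizer)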