mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[Fix] Fix Inference Example, Tests, and Requirements (#5688)
* clean requirements * modify example inference struct * add test ci scripts * mark test_infer as submodule * rm deprecated cls & deps * import of HAS_FLASH_ATTN * prune inference tests to be run * prune triton kernel tests * increment pytest timeout mins * revert import path in openmoe
This commit is contained in:
@@ -11,13 +11,16 @@ MAX_LEN = 100
|
||||
SPEC_NUM = 5
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tokenizer():
|
||||
return AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec_num", [SPEC_NUM])
|
||||
def test_drafter(spec_num: int):
|
||||
def test_drafter(tokenizer, spec_num: int):
|
||||
torch.manual_seed(123)
|
||||
|
||||
device = get_current_device()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
|
||||
toy_config.pad_token_id = tokenizer.eos_token_id
|
||||
drafter_model = LlamaForCausalLM(toy_config)
|
||||
@@ -39,10 +42,9 @@ def test_drafter(spec_num: int):
|
||||
assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
|
||||
|
||||
|
||||
def test_spec_dec():
|
||||
def test_spec_dec(tokenizer):
|
||||
spec_num = SPEC_NUM
|
||||
device = get_current_device()
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Dummy config for Glide Model
|
||||
@@ -67,5 +69,6 @@ def test_spec_dec():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_drafter(spec_num=SPEC_NUM)
|
||||
test_spec_dec()
|
||||
dummy_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
test_drafter(dummy_tokenizer, spec_num=SPEC_NUM)
|
||||
test_spec_dec(dummy_tokenizer)
|
||||
|
Reference in New Issue
Block a user