Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 03:52:01 +00:00)
[Fix/Inference] Fix GQA Triton and Support Llama3 (#5624)
* [fix] GQA calling of flash decoding triton
* fix kv cache alloc shape
* fix rotary triton - GQA
* fix sequence max length assigning
* Sequence max length logic
* fix scheduling and spec-dec
* skip without import error
* fix pytest - skip without ImportError

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
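Several of these bullets (the flash-decoding call, the KV cache allocation shape, the rotary kernel) come down to the grouped-query-attention (GQA) head layout used by Llama3, in which several query heads share one key/value head. A minimal sketch of that mapping; the names and shapes below are illustrative, not taken from the ColossalAI code:

    num_q_heads, num_kv_heads = 32, 8            # e.g. a Llama3-8B-style config
    kv_group_num = num_q_heads // num_kv_heads   # 4 query heads share each KV head

    def kv_head_of(q_head: int) -> int:
        # the KV head a given query head reads from
        return q_head // kv_group_num

    # the KV cache is allocated per KV head, not per query head
    num_blocks, block_size, head_dim = 128, 16, 128
    kv_cache_shape = (num_blocks, num_kv_heads, block_size, head_dim)

    assert kv_head_of(0) == 0 and kv_head_of(31) == 7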
@@ -133,8 +133,9 @@ def check_spec_dec(num_layers, max_length):
     assert not engine.use_spec_dec
     assert engine.drafter is None and engine.drafter_model is None

+    max_new_tokens = max_length - dummy_inputs.size(1)
     assert len(out) == 1
-    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
+    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens

     # test GLIDE model
     glide_config = GlideLlamaConfig(
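The new assertion tracks the test's length accounting: max_length caps the total sequence, so the engine's output should contain only the ids generated beyond the prompt. A small worked example, with a hypothetical tensor standing in for the test's dummy_inputs:

    import torch

    dummy_inputs = torch.randint(0, 1000, (1, 8))       # hypothetical 8-token prompt
    max_length = 64                                     # total-sequence cap, as in the test
    max_new_tokens = max_length - dummy_inputs.size(1)
    assert max_new_tokens == 56                         # generated ids on top of the prompt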
@@ -152,7 +153,7 @@ def check_spec_dec(num_layers, max_length):
     engine.clear_spec_dec()

     assert len(out) == 1
-    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
+    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens


 def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
@@ -186,7 +187,7 @@ def test_tp_engine(prompt_template, do_sample):


 @parameterize("num_layers", [1])
-@parameterize("max_length", [100])
+@parameterize("max_length", [64])
 def test_spec_dec(num_layers, max_length):
     spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length)

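One listed fix, "skip without ImportError", is not visible in this excerpt: it skips these tests when the Triton kernels cannot be imported, instead of failing at collection. A minimal sketch of that pattern; the guard flag and test name are hypothetical:

    import pytest

    try:
        import triton  # noqa: F401
        HAS_TRITON = True
    except ImportError:
        HAS_TRITON = False

    @pytest.mark.skipif(not HAS_TRITON, reason="requires triton")
    def test_flash_decoding_gqa():
        ...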