[Inference/SpecDec] Support GLIDE Drafter Model (#5455)

* add glide-llama policy and modeling

* update glide modeling, compatible with transformers 4.36.2

* revise glide llama modeling/usage

* fix issues of glimpsing large kv

* revise the way of re-loading params for glide drafter

* fix drafter and engine tests

* enable convert to glide strict=False

* revise glide llama modeling

* revise vicuna prompt template

* revise drafter and tests

* apply usage of glide model in engine
This commit is contained in:
Yuanheng Zhao
2024-04-01 21:54:24 +08:00
committed by Yuanheng
parent 912e24b2aa
commit d85d91435a
10 changed files with 722 additions and 82 deletions

View File

@@ -2,18 +2,16 @@ import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import GenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
from colossalai.inference.spec.drafter import Drafter
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
# Module-level constants shared by the tests in this file.
# NOTE(review): NUM_LAYERS is assigned twice in a row, so only the second
# value (1) takes effect. This looks like the old and new lines of a rendered
# diff shown together -- confirm which assignment the file really contains.
NUM_LAYERS = 2
NUM_LAYERS = 1
# Passed as `max_length` for generation and checked against the output length in check_sd.
MAX_LEN = 100
# Number of speculated tokens; feeds the test_drafter parametrization and test_spec_dec.
SPEC_NUM = 5
@pytest.mark.parametrize("spec_num", [5])
@pytest.mark.parametrize("spec_num", [SPEC_NUM])
def test_drafter(spec_num: int):
torch.manual_seed(123)
@@ -41,68 +39,33 @@ def test_drafter(spec_num: int):
assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
def check_sd():
    """Smoke-test speculative decoding on the inference engine.

    A small Llama drafter and a larger Llama main model are built from dummy
    configs, speculative decoding is enabled, and one random prompt is
    generated. Afterwards the spec-dec state must be cleared and exactly one
    sequence of ``MAX_LEN`` token ids must have been produced.
    """
    # Seed first: model initialisation and the random prompt both draw from the RNG.
    torch.manual_seed(123)

    tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    # Dummy configs for testing
    drafter_cfg = LlamaConfig(num_hidden_layers=NUM_LAYERS)
    drafter_cfg.pad_token_id = tok.eos_token_id
    drafter_model = LlamaForCausalLM(drafter_cfg)
    drafter_model = drafter_model.eval().cuda()

    main_cfg_kwargs = dict(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_hidden_layers=8,
        num_key_value_heads=32,
        max_position_embeddings=2048,
    )
    main_cfg = LlamaConfig(**main_cfg_kwargs)
    main_cfg.pad_token_id = tok.eos_token_id
    main_model = LlamaForCausalLM(main_cfg)

    engine_cfg = InferenceConfig(
        dtype="fp16",
        micro_batch_size=1,
        max_batch_size=1,
        max_input_len=128,
        max_output_len=128,
        prefill_ratio=1.2,
        block_size=16,
    )
    engine = InferenceEngine(main_model, tok, engine_cfg)
    engine.enable_spec_dec(drafter_model, n_spec_tokens=5)

    # A single prompt of 10 random token ids.
    prompt_ids = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
    gen_cfg = GenerationConfig(
        pad_token_id=tok.eos_token_id,
        max_length=MAX_LEN,
        eos_token_id=tok.eos_token_id,
    )
    out, out_token_ids = engine.generate(
        prompts_token_ids=prompt_ids,
        generation_config=gen_cfg,
        return_token_ids=True,
    )

    # Tear speculative decoding down and verify nothing is left behind.
    engine.disable_spec_dec()
    engine.clear_spec_dec()
    assert not engine.use_spec_dec
    assert engine.drafter is None and engine.drafter_model is None

    assert len(out) == 1
    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == MAX_LEN
def run_dist(rank, world_size, port):
    """Launch colossalai for this process, then run the speculative-decoding check."""
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        port=port,
        host="localhost",
    )
    colossalai.launch(**launch_kwargs)
    check_sd()
# Test entry point around speculative decoding / the GLIDE drafter.
# NOTE(review): the lines below carry no indentation and appear to combine two
# revisions of this test from a rendered diff (the decorated
# `spawn(run_dist, ...)` form and the direct GLIDE drafter form). Confirm which
# lines belong to the current file before relying on this block.
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_spec_dec():
# Presumably runs `run_dist` in one spawned process -- verify against colossalai.testing.spawn.
spawn(run_dist, nprocs=1)
spec_num = SPEC_NUM
device = get_current_device()
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# Dummy config for Glide Model
glide_config = GlideLlamaConfig(
intermediate_size=8192,
large_hidden_size=4096,
large_num_attention_heads=32,
num_hidden_layers=NUM_LAYERS,
)
drafter_model = GlideLlamaForCausalLM(glide_config)
# Structural checks: the model exposes `model.layers`, and every layer has a `cross_attn` attribute.
assert hasattr(drafter_model, "model")
assert hasattr(drafter_model.model, "layers")
for _, layer in enumerate(drafter_model.model.layers):
assert hasattr(layer, "cross_attn")
# Init the Drafter by providing the sharded drafter model
drafter = Drafter(drafter_model, tokenizer, device=device, dtype=torch.float16)
# One prompt of 6 random token ids in [5, 1000), moved to the current device.
input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
# Speculate `spec_num` tokens with no prior KV cache; `out` is not inspected in the visible lines.
out = drafter.speculate(input_ids, spec_num, past_key_values=None)
# Allow running the tests directly as a script.
if __name__ == "__main__":
# NOTE(review): test_drafter is invoked twice with equivalent arguments
# (SPEC_NUM is 5). These are likely the old and new lines of a rendered diff
# shown together -- confirm which call the file really contains.
test_drafter(spec_num=5)
test_drafter(spec_num=SPEC_NUM)
test_spec_dec()

View File

@@ -9,6 +9,7 @@ import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@@ -80,9 +81,81 @@ def check_output_consistency(prompt_template):
FDIntermTensors._instances = {}
@parameterize("num_layers", [1])
@parameterize("max_length", [100])
def check_spec_dec(num_layers, max_length):
    """Exercise speculative decoding with a plain Llama drafter, then a GLIDE drafter.

    Args:
        num_layers: Number of hidden layers used for both drafter models.
        max_length: ``max_length`` passed to generation; each returned token-id
            sequence is expected to have exactly this length.
    """
    # Seed first: model initialisation and the random prompt both draw from the RNG.
    torch.manual_seed(123)

    tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    # Dummy configs for testing
    drafter_cfg = LlamaConfig(num_hidden_layers=num_layers)
    drafter_cfg.pad_token_id = tok.eos_token_id
    drafter_model = LlamaForCausalLM(drafter_cfg)
    drafter_model = drafter_model.eval().cuda()

    main_cfg_kwargs = dict(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_hidden_layers=8,
        num_key_value_heads=32,
        max_position_embeddings=2048,
    )
    main_cfg = LlamaConfig(**main_cfg_kwargs)
    main_cfg.pad_token_id = tok.eos_token_id
    main_model = LlamaForCausalLM(main_cfg)

    engine_cfg = InferenceConfig(
        dtype="fp16",
        micro_batch_size=1,
        max_batch_size=1,
        max_input_len=128,
        max_output_len=128,
        prefill_ratio=1.2,
        block_size=16,
    )
    engine = InferenceEngine(main_model, tok, engine_cfg)
    engine.enable_spec_dec(drafter_model, n_spec_tokens=5)

    # A single prompt of 10 random token ids, reused for both generation passes.
    prompt_ids = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
    gen_cfg = GenerationConfig(
        pad_token_id=tok.eos_token_id,
        max_length=max_length,
        eos_token_id=tok.eos_token_id,
    )
    out, out_token_ids = engine.generate(
        prompts_token_ids=prompt_ids,
        generation_config=gen_cfg,
        return_token_ids=True,
    )

    # Tear speculative decoding down and verify nothing is left behind.
    engine.disable_spec_dec()
    engine.clear_spec_dec()
    assert not engine.use_spec_dec
    assert engine.drafter is None and engine.drafter_model is None

    assert len(out) == 1
    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length

    # test GLIDE model
    glide_cfg = GlideLlamaConfig(
        intermediate_size=8192,
        large_hidden_size=4096,
        large_num_attention_heads=32,
        num_hidden_layers=num_layers,
    )
    glide_model = GlideLlamaForCausalLM(glide_cfg)
    engine.enable_spec_dec(glide_model, use_glide_drafter=True)

    out, out_token_ids = engine.generate(
        prompts_token_ids=prompt_ids,
        generation_config=gen_cfg,
        return_token_ids=True,
    )
    engine.clear_spec_dec()

    assert len(out) == 1
    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
def run_dist(rank, world_size, port):
    """Launch colossalai for this process, then run the engine checks in order."""
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        port=port,
        host="localhost",
    )
    colossalai.launch(**launch_kwargs)
    check_output_consistency()
    check_spec_dec()
@pytest.mark.dist