Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 01:28:31 +00:00
[Inference/SpecDec] Support GLIDE Drafter Model (#5455)
* add glide-llama policy and modeling
* update glide modeling, compatible with transformers 4.36.2
* revise glide llama modeling/usage
* fix issues of glimpsing large kv
* revise the way of re-loading params for the glide drafter
* fix drafter and engine tests
* enable converting to glide with strict=False
* revise glide llama modeling
* revise vicuna prompt template
* revise drafter and tests
* apply usage of glide model in engine
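As a rough usage sketch of the last point above: the snippet below mirrors the dummy setup used in the tests changed by this commit (it assumes a CUDA device; the model sizes and config values are illustrative, taken from the tests, not requirements of the library).

# Sketch only: wiring a GLIDE drafter into InferenceEngine, following the tests in this commit.
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

from colossalai.inference.config import GenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

# Verifier (main) model: a randomly initialized toy Llama, as in the tests.
large_config = LlamaConfig(
    hidden_size=4096, intermediate_size=11008, num_attention_heads=32,
    num_hidden_layers=8, num_key_value_heads=32, max_position_embeddings=2048,
)
large_config.pad_token_id = tokenizer.eos_token_id
main_model = LlamaForCausalLM(large_config)

# GLIDE drafter: large_hidden_size / large_num_attention_heads describe the verifier
# whose KV cache the drafter's cross-attention glimpses.
glide_config = GlideLlamaConfig(
    intermediate_size=8192, large_hidden_size=4096,
    large_num_attention_heads=32, num_hidden_layers=1,
)
glide_model = GlideLlamaForCausalLM(glide_config)

inference_config = InferenceConfig(
    dtype="fp16", micro_batch_size=1, max_batch_size=1,
    max_input_len=128, max_output_len=128, prefill_ratio=1.2, block_size=16,
)
engine = InferenceEngine(main_model, tokenizer, inference_config)

# use_glide_drafter=True treats the drafter as a GLIDE model
# (per this commit, its params are re-loaded with strict=False).
engine.enable_spec_dec(glide_model, use_glide_drafter=True)

prompt_ids = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
generation_config = GenerationConfig(
    pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, max_length=100,
)
out, out_token_ids = engine.generate(
    prompts_token_ids=prompt_ids, generation_config=generation_config, return_token_ids=True
)

engine.clear_spec_dec()  # detach the drafter; engine.use_spec_dec becomes False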
@@ -2,18 +2,16 @@ import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

import colossalai
from colossalai.inference.config import GenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
from colossalai.inference.spec.drafter import Drafter
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device

NUM_LAYERS = 2
NUM_LAYERS = 1
MAX_LEN = 100
SPEC_NUM = 5


@pytest.mark.parametrize("spec_num", [5])
@pytest.mark.parametrize("spec_num", [SPEC_NUM])
def test_drafter(spec_num: int):
    torch.manual_seed(123)

@@ -41,68 +39,33 @@ def test_drafter(spec_num: int):
    assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num


def check_sd():
    torch.manual_seed(123)

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    # Dummy configs for testing
    toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
    toy_config.pad_token_id = tokenizer.eos_token_id
    drafter_model = LlamaForCausalLM(toy_config)
    drafter_model = drafter_model.eval().cuda()
    large_config = LlamaConfig(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_hidden_layers=8,
        num_key_value_heads=32,
        max_position_embeddings=2048,
    )
    large_config.pad_token_id = tokenizer.eos_token_id
    main_model = LlamaForCausalLM(large_config)

    inference_config = InferenceConfig(
        dtype="fp16",
        micro_batch_size=1,
        max_batch_size=1,
        max_input_len=128,
        max_output_len=128,
        prefill_ratio=1.2,
        block_size=16,
    )
    engine = InferenceEngine(main_model, tokenizer, inference_config)
    engine.enable_spec_dec(drafter_model, n_spec_tokens=5)

    dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
    generation_config = GenerationConfig(
        pad_token_id=tokenizer.eos_token_id,
        max_length=MAX_LEN,
        eos_token_id=tokenizer.eos_token_id,
    )
    out, out_token_ids = engine.generate(
        prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
    )
    engine.disable_spec_dec()
    engine.clear_spec_dec()

    assert not engine.use_spec_dec
    assert engine.drafter is None and engine.drafter_model is None

    assert len(out) == 1
    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == MAX_LEN


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    check_sd()


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_spec_dec():
    spawn(run_dist, nprocs=1)
    spec_num = SPEC_NUM
    device = get_current_device()
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    tokenizer.pad_token = tokenizer.eos_token

    # Dummy config for Glide Model
    glide_config = GlideLlamaConfig(
        intermediate_size=8192,
        large_hidden_size=4096,
        large_num_attention_heads=32,
        num_hidden_layers=NUM_LAYERS,
    )
    drafter_model = GlideLlamaForCausalLM(glide_config)

    assert hasattr(drafter_model, "model")
    assert hasattr(drafter_model.model, "layers")
    for _, layer in enumerate(drafter_model.model.layers):
        assert hasattr(layer, "cross_attn")

    # Init the Drafter by providing the sharded drafter model
    drafter = Drafter(drafter_model, tokenizer, device=device, dtype=torch.float16)

    input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
    out = drafter.speculate(input_ids, spec_num, past_key_values=None)


if __name__ == "__main__":
    test_drafter(spec_num=5)
    test_drafter(spec_num=SPEC_NUM)
    test_spec_dec()

@@ -9,6 +9,7 @@ import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

@@ -80,9 +81,81 @@ def check_output_consistency(prompt_template):
    FDIntermTensors._instances = {}


@parameterize("num_layers", [1])
@parameterize("max_length", [100])
def check_spec_dec(num_layers, max_length):
    torch.manual_seed(123)

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    # Dummy configs for testing
    toy_config = LlamaConfig(num_hidden_layers=num_layers)
    toy_config.pad_token_id = tokenizer.eos_token_id
    drafter_model = LlamaForCausalLM(toy_config)
    drafter_model = drafter_model.eval().cuda()
    large_config = LlamaConfig(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_hidden_layers=8,
        num_key_value_heads=32,
        max_position_embeddings=2048,
    )
    large_config.pad_token_id = tokenizer.eos_token_id
    main_model = LlamaForCausalLM(large_config)

    inference_config = InferenceConfig(
        dtype="fp16",
        micro_batch_size=1,
        max_batch_size=1,
        max_input_len=128,
        max_output_len=128,
        prefill_ratio=1.2,
        block_size=16,
    )
    engine = InferenceEngine(main_model, tokenizer, inference_config)
    engine.enable_spec_dec(drafter_model, n_spec_tokens=5)

    dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
    generation_config = GenerationConfig(
        pad_token_id=tokenizer.eos_token_id,
        max_length=max_length,
        eos_token_id=tokenizer.eos_token_id,
    )
    out, out_token_ids = engine.generate(
        prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
    )
    engine.disable_spec_dec()
    engine.clear_spec_dec()

    assert not engine.use_spec_dec
    assert engine.drafter is None and engine.drafter_model is None

    assert len(out) == 1
    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length

    # test GLIDE model
    glide_config = GlideLlamaConfig(
        intermediate_size=8192,
        large_hidden_size=4096,
        large_num_attention_heads=32,
        num_hidden_layers=num_layers,
    )
    glide_model = GlideLlamaForCausalLM(glide_config)
    engine.enable_spec_dec(glide_model, use_glide_drafter=True)

    out, out_token_ids = engine.generate(
        prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
    )
    engine.clear_spec_dec()

    assert len(out) == 1
    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    check_output_consistency()
    check_spec_dec()


@pytest.mark.dist