[Inference/SpecDec] Support GLIDE Drafter Model (#5455)

* add glide-llama policy and modeling

* update glide modeling, compatible with transformers 4.36.2

* revise glide llama modeling/usage

* fix issues of glimpsing large kv

* revise the way params are re-loaded for glide drafter

* fix drafter and engine tests

* enable convert to glide strict=False

* revise glide llama modeling

* revise vicuna prompt template

* revise drafter and tests

* apply usage of glide model in engine
This commit is contained in:
Yuanheng Zhao
2024-04-01 21:54:24 +08:00
committed by Yuanheng
parent 912e24b2aa
commit d85d91435a
10 changed files with 722 additions and 82 deletions

View File

@@ -0,0 +1,45 @@
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
from colossalai.inference.modeling.models.glide_llama import (
GlideLlamaDecoderLayer,
glide_llama_causal_lm_forward,
glide_llama_model_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
class GlideLlamaModelPolicy(LlamaForCausalLMPolicy):
    """Shardformer policy that turns a Llama causal LM into its GLIDE variant.

    On top of the parent ``LlamaForCausalLMPolicy`` policy, it registers:

    * a submodule replacement swapping every ``layers[i]`` of ``LlamaModel``
      for ``GlideLlamaDecoderLayer``;
    * ``forward`` method replacements for ``LlamaModel`` and
      ``LlamaForCausalLM``.
    """

    def module_policy(self):
        """Return the parent policy extended with the GLIDE replacements."""
        policy = super().module_policy()

        # One replacement description per decoder layer; the suffix is
        # resolved relative to the ``LlamaModel`` target below.
        layer_replacements = []
        for layer_idx in range(self.model.config.num_hidden_layers):
            layer_replacements.append(
                SubModuleReplacementDescription(
                    suffix=f"layers[{layer_idx}]",
                    target_module=GlideLlamaDecoderLayer,
                )
            )
        self.append_or_create_submodule_replacement(
            description=layer_replacements,
            policy=policy,
            target_key=LlamaModel,
        )

        # Override ``forward`` on the base model first, then on the LM wrapper.
        forward_overrides = (
            (LlamaModel, glide_llama_model_forward),
            (LlamaForCausalLM, glide_llama_causal_lm_forward),
        )
        for target_cls, forward_fn in forward_overrides:
            self.append_or_create_method_replacement(
                description={"forward": forward_fn},
                policy=policy,
                target_key=target_cls,
            )

        return policy

    def postprocess(self):
        """Run ``init_to_get_rotary`` on each layer's ``cross_attn``; return the model.

        NOTE(review): presumably this prepares rotary-embedding state on the
        cross-attention modules -- confirm against ``init_to_get_rotary``.
        """
        decoder_layers = self.model.model.layers
        for decoder_layer in decoder_layers:
            cross_attention = decoder_layer.cross_attn
            init_to_get_rotary(cross_attention)
        return self.model