Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-11 05:49:55 +00:00
[Inference/SpecDec] Support GLIDE Drafter Model (#5455)
* add glide-llama policy and modeling
* update glide modeling, compatible with transformers 4.36.2
* revise glide llama modeling/usage
* fix issues of glimpsing large kv
* revise the way of re-loading params for the glide drafter
* fix drafter and engine tests
* enable converting to glide with strict=False
* revise glide llama modeling
* revise vicuna prompt template
* revise drafter and tests
* apply usage of the glide model in the engine (sketched below)
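The last bullet refers to wiring the GLIDE drafter into the inference engine for speculative decoding. The snippet below is a hypothetical, minimal sketch of that flow, not code from this commit: the engine-facing names (InferenceEngine, InferenceConfig, enable_spec_dec, use_glide_drafter), the GlideLlamaForCausalLM class name, and the checkpoint paths are all assumptions inferred from the commit description.

# Hypothetical sketch: speculative decoding with a GLIDE drafter.
# Engine-facing names (InferenceEngine, InferenceConfig, enable_spec_dec,
# use_glide_drafter) and the GlideLlamaForCausalLM class name are assumptions
# inferred from the commit description, not verified against this commit.
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

from colossalai.inference.config import InferenceConfig       # assumed module path
from colossalai.inference.core.engine import InferenceEngine  # assumed module path
from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM  # assumed class name

tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
main_model = LlamaForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5", torch_dtype=torch.float16)
drafter = GlideLlamaForCausalLM.from_pretrained("path/to/glide-drafter", torch_dtype=torch.float16)  # placeholder path

engine = InferenceEngine(main_model, tokenizer, InferenceConfig(max_batch_size=1))
# Enable speculative decoding; the GLIDE flag lets the drafter "glimpse" the
# verifier's large KV cache through its cross attention at every draft step.
engine.enable_spec_dec(drafter_model=drafter, n_spec_tokens=5, use_glide_drafter=True)

out = engine.generate(prompts=["What is speculative decoding?"])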
colossalai/inference/modeling/policy/glide_llama.py (new file, +45 lines)
@@ -0,0 +1,45 @@
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel

from colossalai.inference.modeling.models.glide_llama import (
    GlideLlamaDecoderLayer,
    glide_llama_causal_lm_forward,
    glide_llama_model_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy


class GlideLlamaModelPolicy(LlamaForCausalLMPolicy):
    def module_policy(self):
        policy = super().module_policy()

        num_layers = self.model.config.num_hidden_layers
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix=f"layers[{i}]",
                    target_module=GlideLlamaDecoderLayer,
                )
                for i in range(num_layers)
            ],
            policy=policy,
            target_key=LlamaModel,
        )
        self.append_or_create_method_replacement(
            description={"forward": glide_llama_model_forward},
            policy=policy,
            target_key=LlamaModel,
        )
        self.append_or_create_method_replacement(
            description={"forward": glide_llama_causal_lm_forward},
            policy=policy,
            target_key=LlamaForCausalLM,
        )

        return policy

    def postprocess(self):
        for layer in self.model.model.layers:
            init_to_get_rotary(layer.cross_attn)
        return self.model
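Since GlideLlamaModelPolicy is a ShardFormer policy, applying it would follow the usual ShardFormer.optimize pattern. The snippet below is a minimal sketch of that pattern and is not part of this commit; the checkpoint name is a placeholder, and the single-process ShardConfig settings are assumptions.

# Minimal sketch: applying GlideLlamaModelPolicy to a Llama model via ShardFormer.
# Not part of this commit; the checkpoint name is a placeholder.
import torch
from transformers import LlamaForCausalLM

from colossalai.inference.modeling.policy.glide_llama import GlideLlamaModelPolicy
from colossalai.shardformer import ShardConfig, ShardFormer

model = LlamaForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype=torch.float16)

# Single-process sketch: tensor parallelism disabled (an assumption for brevity).
shard_config = ShardConfig(enable_tensor_parallelism=False)
shard_former = ShardFormer(shard_config=shard_config)

# module_policy() swaps every decoder layer for GlideLlamaDecoderLayer and
# replaces the LlamaModel / LlamaForCausalLM forwards with the glide variants;
# postprocess() then initializes rotary-embedding buffers for each layer's
# cross attention via init_to_get_rotary.
glide_model, _shared_params = shard_former.optimize(model, policy=GlideLlamaModelPolicy())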