feat(model): Support Phi-3 models (#1554)

Fangyin Cheng 2024-05-23 09:45:32 +08:00 committed by GitHub
parent 47430f2a0b
commit 7f55aa4b6e
5 changed files with 74 additions and 0 deletions


@ -158,6 +158,7 @@ At present, we have introduced several key features to showcase our current capabilities
We offer extensive model support, including dozens of large language models (LLMs) from both open-source and API agents, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and many more.
- News
- 🔥🔥🔥 [Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3)
- 🔥🔥🔥 [Yi-1.5-34B-Chat](https://huggingface.co/01-ai/Yi-1.5-34B-Chat)
- 🔥🔥🔥 [Yi-1.5-9B-Chat](https://huggingface.co/01-ai/Yi-1.5-9B-Chat)
- 🔥🔥🔥 [Yi-1.5-6B-Chat](https://huggingface.co/01-ai/Yi-1.5-6B-Chat)


@ -152,6 +152,7 @@
Extensive model support: dozens of large language models, including open-source models and API proxies, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and more. The following models are currently supported:
- Newly supported models
- 🔥🔥🔥 [Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3)
- 🔥🔥🔥 [Yi-1.5-34B-Chat](https://huggingface.co/01-ai/Yi-1.5-34B-Chat)
- 🔥🔥🔥 [Yi-1.5-9B-Chat](https://huggingface.co/01-ai/Yi-1.5-9B-Chat)
- 🔥🔥🔥 [Yi-1.5-6B-Chat](https://huggingface.co/01-ai/Yi-1.5-6B-Chat)


@ -187,6 +187,16 @@ LLM_MODEL_CONFIG = {
"gemma-2b-it": os.path.join(MODEL_PATH, "gemma-2b-it"),
"starling-lm-7b-beta": os.path.join(MODEL_PATH, "Starling-LM-7B-beta"),
"deepseek-v2-lite-chat": os.path.join(MODEL_PATH, "DeepSeek-V2-Lite-Chat"),
"sailor-14b-chat": os.path.join(MODEL_PATH, "Sailor-14B-Chat"),
# https://huggingface.co/microsoft/Phi-3-medium-128k-instruct
"phi-3-medium-128k-instruct": os.path.join(
MODEL_PATH, "Phi-3-medium-128k-instruct"
),
"phi-3-medium-4k-instruct": os.path.join(MODEL_PATH, "Phi-3-medium-4k-instruct"),
"phi-3-small-128k-instruct": os.path.join(MODEL_PATH, "Phi-3-small-128k-instruct"),
"phi-3-small-8k-instruct": os.path.join(MODEL_PATH, "Phi-3-small-8k-instruct"),
"phi-3-mini-128k-instruct": os.path.join(MODEL_PATH, "Phi-3-mini-128k-instruct"),
"phi-3-mini-4k-instruct": os.path.join(MODEL_PATH, "Phi-3-mini-4k-instruct"),
}
EMBEDDING_MODEL_CONFIG = {
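
For reference, a minimal sketch of how a registered name resolves to a local checkout directory. The `resolve_model_path` helper and the `MODEL_PATH` value are stand-ins for this demo, not the project's actual API; the registry itself is just a plain dict keyed by lowercase model name:

```python
import os

MODEL_PATH = os.path.expanduser("~/models")  # stand-in for the module's MODEL_PATH
LLM_MODEL_CONFIG = {
    "phi-3-mini-4k-instruct": os.path.join(MODEL_PATH, "Phi-3-mini-4k-instruct"),
}

def resolve_model_path(model_name: str) -> str:
    # Plain dict lookup: key is the configured model name, value is the
    # directory where the downloaded checkpoint is expected to live.
    if model_name not in LLM_MODEL_CONFIG:
        raise ValueError(f"Unknown model: {model_name!r}")
    return LLM_MODEL_CONFIG[model_name]

print(resolve_model_path("phi-3-mini-4k-instruct"))
```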


@ -396,6 +396,61 @@ class DeepseekV2Adapter(NewHFChatModelAdapter):
        return model, tokenizer


class SailorAdapter(QwenAdapter):
    """
    https://huggingface.co/sail/Sailor-14B-Chat
    """

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return (
            lower_model_name_or_path
            and "sailor" in lower_model_name_or_path
            and "chat" in lower_model_name_or_path
        )


class PhiAdapter(NewHFChatModelAdapter):
    """
    https://huggingface.co/microsoft/Phi-3-medium-128k-instruct
    """

    support_4bit: bool = True
    support_8bit: bool = True
    support_system_message: bool = False

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return (
            lower_model_name_or_path
            and "phi-3" in lower_model_name_or_path
            and "instruct" in lower_model_name_or_path
        )

    def load(self, model_path: str, from_pretrained_kwargs: dict):
        if not from_pretrained_kwargs:
            from_pretrained_kwargs = {}
        if "trust_remote_code" not in from_pretrained_kwargs:
            from_pretrained_kwargs["trust_remote_code"] = True
        return super().load(model_path, from_pretrained_kwargs)

    def get_str_prompt(
        self,
        params: Dict,
        messages: List[ModelMessage],
        tokenizer: Any,
        prompt_template: Optional[str] = None,
        convert_to_compatible_format: bool = False,
    ) -> Optional[str]:
        str_prompt = super().get_str_prompt(
            params,
            messages,
            tokenizer,
            prompt_template,
            convert_to_compatible_format,
        )
        # Phi-3 ends each turn with "<|end|>"; register it so the streaming
        # generator can strip it from the decoded output.
        params["custom_stop_words"] = ["<|end|>"]
        return str_prompt
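
As a quick sanity check of the matching rule above (illustrative only; it assumes `PhiAdapter` can be instantiated with no arguments and is called with an already-lowercased name):

```python
adapter = PhiAdapter()
assert adapter.do_match("microsoft/phi-3-mini-4k-instruct")
assert not adapter.do_match("microsoft/phi-2")       # no "phi-3" in the name
assert not adapter.do_match("microsoft/phi-3-mini")  # no "instruct" in the name
```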

# The following code registers the model adapters.
# The last registered adapter is matched first.
register_model_adapter(YiAdapter)
@ -408,3 +463,5 @@ register_model_adapter(QwenAdapter)
register_model_adapter(QwenMoeAdapter)
register_model_adapter(Llama3Adapter)
register_model_adapter(DeepseekV2Adapter)
register_model_adapter(SailorAdapter)
register_model_adapter(PhiAdapter)
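
A minimal sketch of the dispatch implied by that comment. The names `_ADAPTERS` and `match_adapter` are illustrative only, not DB-GPT's actual registry API:

```python
from typing import List, Optional

_ADAPTERS: List[object] = []

def register_model_adapter(adapter_cls) -> None:
    """Instantiate and append an adapter (latest registration goes last)."""
    _ADAPTERS.append(adapter_cls())

def match_adapter(model_name_or_path: str) -> Optional[object]:
    lower = model_name_or_path.lower()
    # Scan back to front so the most recently registered (and typically
    # most specific) adapter wins over earlier, broader ones.
    for adapter in reversed(_ADAPTERS):
        if adapter.do_match(lower):
            return adapter
    return None
```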


@ -22,6 +22,7 @@ def huggingface_chat_generate_stream(
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    stop_token_ids = params.get("stop_token_ids", [])
    do_sample = params.get("do_sample", None)
    custom_stop_words = params.get("custom_stop_words", [])

    input_ids = tokenizer(prompt).input_ids
    # input_ids = input_ids.to(device)
@ -62,4 +63,8 @@ def huggingface_chat_generate_stream(
out = ""
for new_text in streamer:
out += new_text
if custom_stop_words:
for stop_word in custom_stop_words:
if out.endswith(stop_word):
out = out[: -len(stop_word)]
yield out
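
To make the trimming behavior concrete, a standalone sketch; the helper name `trim_stop_words` is made up for this demo and mirrors the loop above:

```python
from typing import List

def trim_stop_words(out: str, custom_stop_words: List[str]) -> str:
    # Drop one trailing stop word, if present, before handing text on.
    for stop_word in custom_stop_words:
        if out.endswith(stop_word):
            out = out[: -len(stop_word)]
    return out

assert trim_stop_words("Hello<|end|>", ["<|end|>"]) == "Hello"
assert trim_stop_words("Hello<|en", ["<|end|>"]) == "Hello<|en"  # partial marker passes through
```

Note that a stop word split across two streamed chunks is only trimmed once the full marker has accumulated at the end of `out`, so intermediate yields may briefly contain a partial marker.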