feat: Support llama.cpp

FangYin Cheng
2023-08-15 18:58:15 +08:00
parent dbae09203f
commit b5fd5d2a3a
15 changed files with 652 additions and 73 deletions

pilot/server/chat_adapter.py

@@ -13,7 +13,7 @@ class BaseChatAdpter:
     and fetch output from model"""
 
     def match(self, model_path: str):
-        return True
+        return False
 
     def get_generate_stream_func(self, model_path: str):
         """Return the generate stream handler func"""
@@ -24,7 +24,9 @@ class BaseChatAdpter:
     def get_conv_template(self, model_path: str) -> Conversation:
         return None
 
-    def model_adaptation(self, params: Dict, model_path: str) -> Tuple[Dict, Dict]:
+    def model_adaptation(
+        self, params: Dict, model_path: str, prompt_template: str = None
+    ) -> Tuple[Dict, Dict]:
         """Params adaptation"""
         conv = self.get_conv_template(model_path)
         messages = params.get("messages")
@@ -39,6 +41,10 @@ class BaseChatAdpter:
]
params["messages"] = messages
if prompt_template:
print(f"Use prompt template {prompt_template} from config")
conv = get_conv_template(prompt_template)
if not conv or not messages:
# Nothing to do
print(
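
Note on the new prompt_template parameter: when set, a template configured by the operator overrides whatever get_conv_template inferred from the model path. A minimal sketch of the resolution order this hunk establishes (resolve_conv_template is a hypothetical helper, for illustration only):

def resolve_conv_template(adapter, model_path, prompt_template=None):
    # Default: infer the conversation template from the model path.
    conv = adapter.get_conv_template(model_path)
    if prompt_template:
        # An explicitly configured template takes precedence over inference.
        conv = get_conv_template(prompt_template)
    return conv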
@@ -94,14 +100,19 @@ def register_llm_model_chat_adapter(cls):
 
 @cache
-def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
+def get_llm_chat_adapter(model_name: str, model_path: str) -> BaseChatAdpter:
     """Get a chat generate func for a model"""
+    for adapter in llm_model_chat_adapters:
+        if adapter.match(model_name):
+            print(f"Get model chat adapter with model name {model_name}, {adapter}")
+            return adapter
     for adapter in llm_model_chat_adapters:
         if adapter.match(model_path):
-            print(f"Get model path: {model_path} adapter {adapter}")
+            print(f"Get model chat adapter with model path {model_path}, {adapter}")
             return adapter
-    raise ValueError(f"Invalid model for chat adapter {model_path}")
+    raise ValueError(
+        f"Invalid model for chat adapter with model name {model_name} and model path {model_path}"
+    )
 
 
 class VicunaChatAdapter(BaseChatAdpter):
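
The lookup is now two-pass: every adapter is first matched against the model name, and only if none claims the name does matching fall back to the model path. This lets a short configured name such as "llama-cpp" select an adapter even when the file path alone is ambiguous. Usage sketch (paths are hypothetical):

# First pass: LlamaCppChatAdapter.match("llama-cpp") returns True,
# so the path is never consulted.
adapter = get_llm_chat_adapter("llama-cpp", "/data/models/ggml-model-q4_0.bin")

# No adapter claims the name, so the second pass matches on the path instead.
adapter = get_llm_chat_adapter("my-vicuna", "/data/models/vicuna-13b")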
@@ -239,6 +250,24 @@ class WizardLMChatAdapter(BaseChatAdpter):
         return get_conv_template("vicuna_v1.1")
 
 
+class LlamaCppChatAdapter(BaseChatAdpter):
+    def match(self, model_path: str):
+        from pilot.model.adapter import LlamaCppAdapater
+
+        if "llama-cpp" == model_path:
+            return True
+        is_match, _ = LlamaCppAdapater._parse_model_path(model_path)
+        return is_match
+
+    def get_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("llama-2")
+
+    def get_generate_stream_func(self, model_path: str):
+        from pilot.model.llm_out.llama_cpp_llm import generate_stream
+
+        return generate_stream
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
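
LlamaCppChatAdapter claims either the literal model name "llama-cpp" or any path that LlamaCppAdapater._parse_model_path recognizes; that helper lives in pilot/model/adapter.py and is not shown in this diff. A hedged sketch of the kind of check such a parser typically performs (illustrative only, the heuristic is assumed, not the project's actual rule):

import os

def _parse_model_path_sketch(model_path: str):
    # llama.cpp loads a single quantized weights file, so a plausible
    # heuristic is an existing file with a ggml-style name.
    name = os.path.basename(model_path)
    if os.path.isfile(model_path) and name.endswith(".bin") and "ggml" in name:
        return True, model_path
    return False, None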
@@ -248,6 +277,7 @@ register_llm_model_chat_adapter(GPT4AllChatAdapter)
 register_llm_model_chat_adapter(Llama2ChatAdapter)
 register_llm_model_chat_adapter(BaichuanChatAdapter)
 register_llm_model_chat_adapter(WizardLMChatAdapter)
+register_llm_model_chat_adapter(LlamaCppChatAdapter)
 
 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)

pilot/server/llmserver.py

@@ -23,7 +23,7 @@ sys.path.append(ROOT_PATH)
 from pilot.configs.config import Config
 from pilot.configs.model_config import *
 from pilot.model.llm_out.vicuna_base_llm import get_embeddings
-from pilot.model.loader import ModelLoader
+from pilot.model.loader import ModelLoader, _get_model_real_path
 from pilot.server.chat_adapter import get_llm_chat_adapter
 from pilot.scene.base_message import ModelMessage
@@ -34,12 +34,13 @@ class ModelWorker:
     def __init__(self, model_path, model_name, device):
         if model_path.endswith("/"):
             model_path = model_path[:-1]
-        self.model_name = model_name or model_path.split("/")[-1]
+        model_path = _get_model_real_path(model_name, model_path)
+        # self.model_name = model_name or model_path.split("/")[-1]
         self.device = device
-        print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
-        self.ml: ModelLoader = ModelLoader(
-            model_path=model_path, model_name=self.model_name
-        )
+        print(
+            f"Loading {model_name} LLM ModelServer in {device} from model path {model_path}! Please Wait......"
+        )
+        self.ml: ModelLoader = ModelLoader(model_path=model_path, model_name=model_name)
         self.model, self.tokenizer = self.ml.loader(
             load_8bit=CFG.IS_LOAD_8BIT,
             load_4bit=CFG.IS_LOAD_4BIT,
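
_get_model_real_path, imported above, resolves the configured path to the concrete weights before loading; its implementation is in pilot/model/loader.py and is outside this diff. A rough sketch of the kind of resolution it enables (behavior and naming convention assumed):

import os

def _get_model_real_path_sketch(model_name: str, default_model_path: str) -> str:
    # Assumed behavior: for llama.cpp, point at a single quantized
    # weights file under the configured directory; otherwise pass through.
    if model_name == "llama-cpp" and os.path.isdir(default_model_path):
        for fname in sorted(os.listdir(default_model_path)):
            if fname.endswith(".bin"):  # assumed ggml naming convention
                return os.path.join(default_model_path, fname)
    return default_model_path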
@@ -60,7 +61,7 @@ class ModelWorker:
         else:
             self.context_len = 2048
 
-        self.llm_chat_adapter = get_llm_chat_adapter(model_path)
+        self.llm_chat_adapter = get_llm_chat_adapter(model_name, model_path)
         self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
             model_path
         )
@@ -86,7 +87,7 @@ class ModelWorker:
         try:
             # params adaptation
             params, model_context = self.llm_chat_adapter.model_adaptation(
-                params, self.ml.model_path
+                params, self.ml.model_path, prompt_template=self.ml.prompt_template
             )
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
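
Threading prompt_template from the loader into model_adaptation means a template configured once at startup now applies to every request the worker serves. End-to-end sketch (argument values are hypothetical):

worker = ModelWorker(
    model_path="/data/models/ggml-model-q4_0.bin",  # hypothetical path
    model_name="llama-cpp",
    device="cuda",
)
# Each call to the worker's stream method now adapts params with the
# configured prompt template before streaming output from generate_stream.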