Mirror of https://github.com/csunny/DB-GPT.git
feat: Support llama.cpp
@@ -13,7 +13,7 @@ class BaseChatAdpter:
     and fetch output from model"""

     def match(self, model_path: str):
-        return True
+        return False

     def get_generate_stream_func(self, model_path: str):
         """Return the generate stream handler func"""
@@ -24,7 +24,9 @@ class BaseChatAdpter:
     def get_conv_template(self, model_path: str) -> Conversation:
         return None

-    def model_adaptation(self, params: Dict, model_path: str) -> Tuple[Dict, Dict]:
+    def model_adaptation(
+        self, params: Dict, model_path: str, prompt_template: str = None
+    ) -> Tuple[Dict, Dict]:
         """Params adaptation"""
         conv = self.get_conv_template(model_path)
         messages = params.get("messages")
@@ -39,6 +41,10 @@ class BaseChatAdpter:
             ]
         params["messages"] = messages

+        if prompt_template:
+            print(f"Use prompt template {prompt_template} from config")
+            conv = get_conv_template(prompt_template)
+
         if not conv or not messages:
             # Nothing to do
             print(
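
Note: the prompt_template parameter added above lets a template name from configuration override the conversation template the adapter would otherwise pick. A minimal sketch of the new call shape; the model path, message schema, and role value are illustrative assumptions, not the project's exact types:

    from pilot.server.chat_adapter import BaseChatAdpter

    adapter = BaseChatAdpter()  # base class from the file above; placeholder choice
    params = {
        "messages": [{"role": "human", "content": "Hello"}],
        "temperature": 0.7,
    }
    # BaseChatAdpter.get_conv_template() returns None, so without a
    # prompt_template this call would hit the "Nothing to do" branch; passing
    # a registered template name such as "vicuna_v1.1" supplies one instead.
    new_params, model_context = adapter.model_adaptation(
        params, "/models/my-model", prompt_template="vicuna_v1.1"
    )
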
@@ -94,14 +100,19 @@ def register_llm_model_chat_adapter(cls):


 @cache
-def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
+def get_llm_chat_adapter(model_name: str, model_path: str) -> BaseChatAdpter:
     """Get a chat generate func for a model"""
     for adapter in llm_model_chat_adapters:
-        if adapter.match(model_path):
-            print(f"Get model path: {model_path} adapter {adapter}")
+        if adapter.match(model_name):
+            print(f"Get model chat adapter with model name {model_name}, {adapter}")
             return adapter
-
-    raise ValueError(f"Invalid model for chat adapter {model_path}")
+    for adapter in llm_model_chat_adapters:
+        if adapter.match(model_path):
+            print(f"Get model chat adapter with model path {model_path}, {adapter}")
+            return adapter
+    raise ValueError(
+        f"Invalid model for chat adapter with model name {model_name} and model path {model_path}"
+    )


 class VicunaChatAdapter(BaseChatAdpter):
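
Note: get_llm_chat_adapter now takes both a model name and a model path and runs two match passes, name first and path second, so a logical name such as "llama-cpp" can select an adapter even when the on-disk path says nothing about the backend. A usage sketch; the path below is a placeholder:

    from pilot.server.chat_adapter import get_llm_chat_adapter

    # First pass matches the configured name; second pass falls back to the path.
    adapter = get_llm_chat_adapter(
        model_name="llama-cpp",
        model_path="/data/models/ggml-model-q4_0.bin",  # hypothetical path
    )
    stream_func = adapter.get_generate_stream_func("/data/models/ggml-model-q4_0.bin")
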
@@ -239,6 +250,24 @@ class WizardLMChatAdapter(BaseChatAdpter):
         return get_conv_template("vicuna_v1.1")


+class LlamaCppChatAdapter(BaseChatAdpter):
+    def match(self, model_path: str):
+        from pilot.model.adapter import LlamaCppAdapater
+
+        if "llama-cpp" == model_path:
+            return True
+        is_match, _ = LlamaCppAdapater._parse_model_path(model_path)
+        return is_match
+
+    def get_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("llama-2")
+
+    def get_generate_stream_func(self, model_path: str):
+        from pilot.model.llm_out.llama_cpp_llm import generate_stream
+
+        return generate_stream
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
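
Note: LlamaCppChatAdapter delegates path detection to LlamaCppAdapater._parse_model_path, which lives in pilot.model.adapter and is not part of this diff. The sketch below only illustrates the (is_match, resolved_path) contract implied by the call site above; the real rules may differ:

    import os
    from typing import Optional, Tuple

    def _parse_model_path_sketch(model_path: str) -> Tuple[bool, Optional[str]]:
        """Illustrative guess at what could count as a llama.cpp model path."""
        if "llama-cpp" in model_path:
            return True, model_path
        if model_path.endswith((".bin", ".gguf")):
            return True, model_path
        if os.path.isdir(model_path):
            # A directory holding a quantized ggml/gguf file also matches.
            for name in os.listdir(model_path):
                if name.endswith((".bin", ".gguf")):
                    return True, os.path.join(model_path, name)
        return False, None
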
@@ -248,6 +277,7 @@ register_llm_model_chat_adapter(GPT4AllChatAdapter)
 register_llm_model_chat_adapter(Llama2ChatAdapter)
 register_llm_model_chat_adapter(BaichuanChatAdapter)
 register_llm_model_chat_adapter(WizardLMChatAdapter)
+register_llm_model_chat_adapter(LlamaCppChatAdapter)

 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)
@@ -23,7 +23,7 @@ sys.path.append(ROOT_PATH)
 from pilot.configs.config import Config
 from pilot.configs.model_config import *
 from pilot.model.llm_out.vicuna_base_llm import get_embeddings
-from pilot.model.loader import ModelLoader
+from pilot.model.loader import ModelLoader, _get_model_real_path
 from pilot.server.chat_adapter import get_llm_chat_adapter
 from pilot.scene.base_message import ModelMessage
@@ -34,12 +34,13 @@ class ModelWorker:
     def __init__(self, model_path, model_name, device):
         if model_path.endswith("/"):
             model_path = model_path[:-1]
-        self.model_name = model_name or model_path.split("/")[-1]
+        model_path = _get_model_real_path(model_name, model_path)
+        # self.model_name = model_name or model_path.split("/")[-1]
         self.device = device
-        print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
-        self.ml: ModelLoader = ModelLoader(
-            model_path=model_path, model_name=self.model_name
-        )
+        print(
+            f"Loading {model_name} LLM ModelServer in {device} from model path {model_path}! Please Wait......"
+        )
+        self.ml: ModelLoader = ModelLoader(model_path=model_path, model_name=model_name)
         self.model, self.tokenizer = self.ml.loader(
             load_8bit=CFG.IS_LOAD_8BIT,
             load_4bit=CFG.IS_LOAD_4BIT,
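
Note: with the constructor change above, ModelWorker resolves the real on-disk path via _get_model_real_path(model_name, model_path) before handing it to ModelLoader, so a logical name plus a configured path is enough to locate the model file. A usage sketch with placeholder arguments:

    # All values are placeholders; only the constructor shape matches the diff.
    worker = ModelWorker(
        model_path="/data/models/llama-cpp",  # hypothetical directory
        model_name="llama-cpp",
        device="cuda",
    )
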
@@ -60,7 +61,7 @@ class ModelWorker:
         else:
             self.context_len = 2048

-        self.llm_chat_adapter = get_llm_chat_adapter(model_path)
+        self.llm_chat_adapter = get_llm_chat_adapter(model_name, model_path)
         self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
             model_path
         )
@@ -86,7 +87,7 @@ class ModelWorker:
         try:
             # params adaptation
             params, model_context = self.llm_chat_adapter.model_adaptation(
-                params, self.ml.model_path
+                params, self.ml.model_path, prompt_template=self.ml.prompt_template
             )
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
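
Note: for the llama.cpp backend, generate_stream (imported from pilot.model.llm_out.llama_cpp_llm in the adapter above) is consumed by exactly this loop. Below is a minimal sketch of a function compatible with that call site, assuming the loaded model behaves like a llama-cpp-python Llama instance and that the loop expects cumulative text; the real implementation may differ:

    def generate_stream_sketch(model, tokenizer, params, device, max_position_embeddings):
        # Signature mirrors the call above; params keys are assumptions.
        prompt = params.get("prompt", "")
        temperature = float(params.get("temperature", 0.7))
        max_new_tokens = int(params.get("max_new_tokens", 256))

        output = ""
        # llama-cpp-python streams chunks shaped like OpenAI completions.
        for chunk in model(
            prompt, max_tokens=max_new_tokens, temperature=temperature, stream=True
        ):
            output += chunk["choices"][0]["text"]
            yield output
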