mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-17 15:10:14 +00:00
feat: Support llama.cpp
This commit is contained in:
@@ -23,7 +23,7 @@ sys.path.append(ROOT_PATH)
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import *
|
||||
from pilot.model.llm_out.vicuna_base_llm import get_embeddings
|
||||
from pilot.model.loader import ModelLoader
|
||||
from pilot.model.loader import ModelLoader, _get_model_real_path
|
||||
from pilot.server.chat_adapter import get_llm_chat_adapter
|
||||
from pilot.scene.base_message import ModelMessage
|
||||
|
||||
@@ -34,12 +34,13 @@ class ModelWorker:
|
||||
def __init__(self, model_path, model_name, device):
|
||||
if model_path.endswith("/"):
|
||||
model_path = model_path[:-1]
|
||||
self.model_name = model_name or model_path.split("/")[-1]
|
||||
model_path = _get_model_real_path(model_name, model_path)
|
||||
# self.model_name = model_name or model_path.split("/")[-1]
|
||||
self.device = device
|
||||
print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
|
||||
self.ml: ModelLoader = ModelLoader(
|
||||
model_path=model_path, model_name=self.model_name
|
||||
print(
|
||||
f"Loading {model_name} LLM ModelServer in {device} from model path {model_path}! Please Wait......"
|
||||
)
|
||||
self.ml: ModelLoader = ModelLoader(model_path=model_path, model_name=model_name)
|
||||
self.model, self.tokenizer = self.ml.loader(
|
||||
load_8bit=CFG.IS_LOAD_8BIT,
|
||||
load_4bit=CFG.IS_LOAD_4BIT,
|
||||
@@ -60,7 +61,7 @@ class ModelWorker:
|
||||
else:
|
||||
self.context_len = 2048
|
||||
|
||||
self.llm_chat_adapter = get_llm_chat_adapter(model_path)
|
||||
self.llm_chat_adapter = get_llm_chat_adapter(model_name, model_path)
|
||||
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
|
||||
model_path
|
||||
)
|
||||
@@ -86,7 +87,7 @@ class ModelWorker:
|
||||
try:
|
||||
# params adaptation
|
||||
params, model_context = self.llm_chat_adapter.model_adaptation(
|
||||
params, self.ml.model_path
|
||||
params, self.ml.model_path, prompt_template=self.ml.prompt_template
|
||||
)
|
||||
for output in self.generate_stream_func(
|
||||
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
|
||||
|
Reference in New Issue
Block a user