feat: Support llama.cpp

Author: FangYin Cheng
Date: 2023-08-15 18:58:15 +08:00
Parent: dbae09203f
Commit: b5fd5d2a3a
15 changed files with 652 additions and 73 deletions


@@ -23,7 +23,7 @@ sys.path.append(ROOT_PATH)
 from pilot.configs.config import Config
 from pilot.configs.model_config import *
 from pilot.model.llm_out.vicuna_base_llm import get_embeddings
-from pilot.model.loader import ModelLoader
+from pilot.model.loader import ModelLoader, _get_model_real_path
 from pilot.server.chat_adapter import get_llm_chat_adapter
 from pilot.scene.base_message import ModelMessage
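
The hunk only shows the new import; the body of _get_model_real_path lives in pilot.model.loader and is not part of this diff. A minimal sketch of what such a resolver could look like, assuming it maps a registered model name to a single llama.cpp weight file (a ggml *.bin at the time of this commit) when one exists:

import os

# Hypothetical sketch only -- the real _get_model_real_path is defined in
# pilot.model.loader and its implementation is not shown in this diff.
def _get_model_real_path(model_name: str, default_model_path: str) -> str:
    """Resolve a model name to a concrete weight file if one exists."""
    if os.path.isfile(default_model_path):
        # Already a single file, e.g. a llama.cpp ggml checkpoint.
        return default_model_path
    candidate = os.path.join(default_model_path, f"{model_name}.bin")
    if os.path.isfile(candidate):
        return candidate
    # Fall back to the configured path (e.g. a transformers directory).
    return default_model_path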
@@ -34,12 +34,13 @@ class ModelWorker:
     def __init__(self, model_path, model_name, device):
         if model_path.endswith("/"):
             model_path = model_path[:-1]
-        self.model_name = model_name or model_path.split("/")[-1]
+        model_path = _get_model_real_path(model_name, model_path)
+        # self.model_name = model_name or model_path.split("/")[-1]
         self.device = device
-        print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
-        self.ml: ModelLoader = ModelLoader(
-            model_path=model_path, model_name=self.model_name
+        print(
+            f"Loading {model_name} LLM ModelServer in {device} from model path {model_path}! Please Wait......"
         )
+        self.ml: ModelLoader = ModelLoader(model_path=model_path, model_name=model_name)
         self.model, self.tokenizer = self.ml.loader(
             load_8bit=CFG.IS_LOAD_8BIT,
             load_4bit=CFG.IS_LOAD_4BIT,
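
The commented-out line is the key behavioral change in this hunk: the worker no longer derives the model name from the last path component and instead passes the explicit model_name through to ModelLoader. Deriving the name breaks down for llama.cpp weights, where the file name (path below is a hypothetical example) describes the quantization rather than the model:

# Deriving a name from a llama.cpp weight path yields the file name,
# not the model identity, so the explicit model_name is used instead.
model_path = "/data/models/ggml-model-q4_0.bin"
print(model_path.split("/")[-1])  # ggml-model-q4_0.bin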
@@ -60,7 +61,7 @@ class ModelWorker:
         else:
             self.context_len = 2048
-        self.llm_chat_adapter = get_llm_chat_adapter(model_path)
+        self.llm_chat_adapter = get_llm_chat_adapter(model_name, model_path)
         self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
             model_path
         )
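
get_llm_chat_adapter now receives the model name in addition to the path, consistent with the change above: a llama.cpp model path no longer necessarily contains a recognizable model identifier. The function's body is outside this diff; a sketch of name-first dispatch under that assumption, with the adapter classes and the "llama-cpp" key being illustrative inventions:

class BaseChatAdapter:
    """Fallback adapter; matches any model."""

    def match(self, model_key: str) -> bool:
        return True

class LlamaCppChatAdapter(BaseChatAdapter):
    """Matches llama.cpp style models by name or weight-file suffix."""

    def match(self, model_key: str) -> bool:
        return "llama-cpp" in model_key or model_key.endswith(".bin")

_adapters = [LlamaCppChatAdapter(), BaseChatAdapter()]

def get_llm_chat_adapter(model_name: str, model_path: str):
    # Try the stable configured name first, then the on-disk path.
    for adapter in _adapters:
        if adapter.match(model_name) or adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model: {model_name} ({model_path})")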
@@ -86,7 +87,7 @@ class ModelWorker:
         try:
             # params adaptation
             params, model_context = self.llm_chat_adapter.model_adaptation(
-                params, self.ml.model_path
+                params, self.ml.model_path, prompt_template=self.ml.prompt_template
             )
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
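
model_adaptation now also receives the loader's prompt_template, letting an explicitly configured template take precedence over whatever would be inferred from the model path. The adapter method itself is not in this diff; a standalone sketch under that assumption, where the template strings and the inference helper are hypothetical:

def model_adaptation(params: dict, model_path: str, prompt_template: str = None):
    """Apply a prompt template, preferring an explicit override."""
    template = prompt_template or _infer_template_from_path(model_path)
    params = dict(params)  # avoid mutating the caller's dict
    params["prompt"] = template.format(prompt=params["prompt"])
    model_context = {"prompt_template": template}
    return params, model_context

def _infer_template_from_path(model_path: str) -> str:
    # e.g. pick a vicuna-style template when the path mentions vicuna
    if "vicuna" in model_path:
        return "USER: {prompt} ASSISTANT:"
    return "{prompt}"

With prompt_template set on the loader, the configured template wins even when the weight file name (e.g. ggml-model-q4_0.bin) says nothing about the model family.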