Mirror of https://github.com/csunny/DB-GPT.git
feature: add model server proxy
commit ea334b172e (parent 4c60ab1ea2)
@@ -96,3 +96,10 @@ VECTOR_STORE_TYPE=Chroma
 LANGUAGE=en
 #LANGUAGE=zh
 
+
+#*******************************************************************#
+# ** PROXY_SERVER
+#*******************************************************************#
+PROXY_API_KEY=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
+PROXY_SERVER_URL=http://43.156.9.162:3000/api/openai/v1/chat/completions
@@ -36,6 +36,10 @@ class Config(metaclass=Singleton):
             " (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
         )
 
+        # This is a proxy server, just for testing; we will remove it later.
+        self.proxy_api_key = os.getenv("PROXY_API_KEY")
+        self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
+
         self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
         self.elevenlabs_voice_1_id = os.getenv("ELEVENLABS_VOICE_1_ID")
         self.elevenlabs_voice_2_id = os.getenv("ELEVENLABS_VOICE_2_ID")
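Note that these two settings are read straight from the environment with no validation; if either is unset they come back as None and only fail later, when proxy_llm.py builds the Bearer header or posts the request. A minimal pre-flight check one could run locally, not part of this commit, with variable names matching the .env keys above:

import os

# Hypothetical sanity check, not part of this commit: fail fast if the proxy
# settings are missing instead of crashing later inside proxy_llm.py.
proxy_api_key = os.getenv("PROXY_API_KEY")
proxy_server_url = os.getenv("PROXY_SERVER_URL")
if not (proxy_api_key and proxy_server_url):
    raise ValueError("Set PROXY_API_KEY and PROXY_SERVER_URL before using the proxyllm model")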
@@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
+    "proxyllm": "proxyllm",
 }
 
 # Load model config
@@ -123,9 +123,24 @@ class GPT4AllAdapter(BaseLLMAdaper):
     def match(self, model_path: str):
         return "gpt4all" in model_path
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
         # TODO
         pass
 
+
+class ProxyllmAdapter(BaseLLMAdaper):
+    """The model adapter for local proxy"""
+
+    def match(self, model_path: str):
+        return "proxyllm" in model_path
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        return "proxyllm", None
+
+
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
+# TODO: vicuna is supported by default; other models still need testing and evaluation
+
+# just for test, remove this later
+register_llm_model_adapters(ProxyllmAdapter)
 register_llm_model_adapters(BaseLLMAdaper)
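For orientation, the register calls above appear to feed a first-match lookup: get_llm_model_adapter(model_path) returns the first registered adapter whose match() accepts the path, which is why ProxyllmAdapter is registered before the catch-all BaseLLMAdaper. A rough sketch of that dispatch, stated as an assumption about the surrounding module rather than a copy of it:

# Sketch of the assumed dispatch; the repo's actual register_llm_model_adapters /
# get_llm_model_adapter live elsewhere and may differ in detail.
llm_model_adapters = []

def register_llm_model_adapters(cls) -> None:
    llm_model_adapters.append(cls())

def get_llm_model_adapter(model_path: str):
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")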
@@ -100,7 +100,7 @@ class ModelLoader(metaclass=Singleton):
         llm_adapter = get_llm_model_adapter(self.model_path)
         model, tokenizer = llm_adapter.loader(self.model_path, kwargs)
 
-        if load_8bit:
+        if load_8bit and tokenizer:
             if num_gpus != 1:
                 warnings.warn(
                     "8-bit quantization is not supported for multi-gpu inference"
@@ -110,7 +110,7 @@ class ModelLoader(metaclass=Singleton):
 
         if (
             self.device == "cuda" and num_gpus == 1 and not cpu_offloading
-        ) or self.device == "mps":
+        ) or self.device == "mps" and tokenizer:
             model.to(self.device)
 
         if debug:
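Both "and tokenizer" guards in the two hunks above exist because ProxyllmAdapter.loader() returns ("proxyllm", None): the "model" is just a string and there is no tokenizer, so the 8-bit quantization path and model.to(self.device) must be skipped. A tiny illustration of the shape ModelLoader now has to tolerate, assuming the adapter class from the earlier hunk is importable:

# Illustration only, not part of the commit.
adapter = ProxyllmAdapter()
model, tokenizer = adapter.loader("proxyllm", {})
assert model == "proxyllm" and tokenizer is None
# A plain string has no .to(device) and cannot be 8-bit quantized, hence the
# "if load_8bit and tokenizer:" and "... and tokenizer:" guards above.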
pilot/model/proxy_llm.py (new file, 56 lines)
@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import requests
from pilot.configs.config import Config
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER

CFG = Config()

def proxyllm_generate_stream(
    model, tokenizer, params, device, context_len=2048
):

    history = []

    prompt = params["prompt"]
    stop = params.get("stop", "###")

    headers = {
        "Authorization": "Bearer " + CFG.proxy_api_key
    }

    messages = prompt.split(stop)

    # Add history conversation
    for i in range(1, len(messages) - 2, 2):
        history.append(
            {"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]},
        )
        history.append(
            {"role": "system", "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1]}
        )

    # Add user query
    query = messages[-2].split(ROLE_USER + ":")[1]
    history.append(
        {"role": "user", "content": query}
    )
    payloads = {
        "model": "gpt-3.5-turbo",  # just for test, remove this later
        "messages": history,
        "temperature": params.get("temperature"),
        "max_tokens": params.get("max_new_tokens"),
    }

    res = requests.post(CFG.proxy_server_url, headers=headers, json=payloads, stream=True)

    text = ""
    for line in res.iter_lines():
        if line:
            decoded_line = line.decode('utf-8')
            json_line = json.loads(decoded_line)
            print(json_line)
            text += json_line["choices"][0]["message"]["content"]
            yield text
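A hedged usage sketch of the new function: the model, tokenizer, and device arguments are effectively ignored, and the prompt is expected to be stop-separated turns whose segments begin with "ROLE:", matching the split logic above. PROXY_API_KEY and PROXY_SERVER_URL must be set in the environment for the request to succeed; the question text below is only an example:

# Minimal sketch, not part of the commit: drive the proxy stream directly.
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
from pilot.model.proxy_llm import proxyllm_generate_stream

stop = "###"
prompt = stop.join(
    [
        "A chat between a curious user and an AI assistant.",  # preamble, skipped by the parser
        f"{ROLE_USER}: What does DB-GPT do?",
        f"{ROLE_ASSISTANT}: ",  # trailing empty assistant turn
    ]
)
params = {"prompt": prompt, "stop": stop, "temperature": 0.7, "max_new_tokens": 512}

for partial_text in proxyllm_generate_stream(None, None, params, device="cpu"):
    print(partial_text)  # accumulated response text so far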
@@ -84,7 +84,7 @@ class CodeGenChatAdapter(BaseChatAdpter):
         pass
 
 
-class GuanacoAdapter(BaseChatAdpter):
+class GuanacoChatAdapter(BaseChatAdpter):
     """Model chat adapter for Guanaco"""
 
     def match(self, model_path: str):
@@ -94,7 +94,20 @@ class GuanacoAdapter(BaseChatAdpter):
         # TODO
         pass
 
 
+class ProxyllmChatAdapter(BaseChatAdpter):
+    def match(self, model_path: str):
+        return "proxyllm" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.proxy_llm import proxyllm_generate_stream
+        return proxyllm_generate_stream
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 
+# Proxy model for testing and development; it's cheap for us for now.
+register_llm_model_chat_adapter(ProxyllmChatAdapter)
+
 register_llm_model_chat_adapter(BaseChatAdpter)
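The chat-side registration presumably mirrors the model-adapter lookup: the worker asks the matching chat adapter for its generate-stream function, and for "proxyllm" it gets the HTTP-backed generator instead of a local inference loop. A short sketch under that assumption; the lookup helper's name here is hypothetical:

# Hypothetical lookup name; the repo's actual helper may be called differently.
chat_adapter = get_llm_chat_adapter("proxyllm")
generate_stream_func = chat_adapter.get_generate_stream_func()
# generate_stream_func is now pilot.model.proxy_llm.proxyllm_generate_stream,
# so the serving code can stay model-agnostic.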
@@ -37,11 +37,12 @@ class ModelWorker:
         self.model, self.tokenizer = self.ml.loader(
             num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
         )
 
-        if hasattr(self.model.config, "max_sequence_length"):
-            self.context_len = self.model.config.max_sequence_length
-        elif hasattr(self.model.config, "max_position_embeddings"):
-            self.context_len = self.model.config.max_position_embeddings
+        if not isinstance(self.model, str):
+            if hasattr(self.model.config, "max_sequence_length"):
+                self.context_len = self.model.config.max_sequence_length
+            elif hasattr(self.model.config, "max_position_embeddings"):
+                self.context_len = self.model.config.max_position_embeddings
 
         else:
             self.context_len = 2048
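The isinstance(self.model, str) check is the same proxy accommodation as in ModelLoader: for proxyllm, self.model is the literal string "proxyllm" and has no .config, so the context length falls back to 2048, which also matches the context_len default in proxyllm_generate_stream. Condensed, the new logic behaves roughly like this:

# Condensed illustration of the hunk above, using a stand-in for self.model.
model = "proxyllm"  # what ml.loader() returns when the proxy adapter matches
if not isinstance(model, str) and hasattr(model, "config"):
    context_len = getattr(model.config, "max_sequence_length", 2048)
else:
    context_len = 2048  # proxy models (and anything without a config) land here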
@@ -434,6 +434,7 @@ def http_bot(
     """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
     """
     if data["error_code"] == 0:
         print("****************:", data)
+        if "vicuna" in CFG.LLM_MODEL:
             output = data["text"][skip_echo_len:].strip()
         else: