feature: add model server proxy

csunny 2023-05-30 17:16:29 +08:00
parent 4c60ab1ea2
commit ea334b172e
9 changed files with 106 additions and 8 deletions

View File

@@ -96,3 +96,10 @@ VECTOR_STORE_TYPE=Chroma
LANGUAGE=en
#LANGUAGE=zh
#*******************************************************************#
# ** PROXY_SERVER
#*******************************************************************#
PROXY_API_KEY=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
PROXY_SERVER_URL=http://43.156.9.162:3000/api/openai/v1/chat/completions

View File

@@ -36,6 +36,10 @@ class Config(metaclass=Singleton):
            " (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
        )
        # This is a proxy server, just for testing. We will remove it later.
        self.proxy_api_key = os.getenv("PROXY_API_KEY")
        self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
        self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
        self.elevenlabs_voice_1_id = os.getenv("ELEVENLABS_VOICE_1_ID")
        self.elevenlabs_voice_2_id = os.getenv("ELEVENLABS_VOICE_2_ID")

View File

@@ -35,6 +35,7 @@ LLM_MODEL_CONFIG = {
    "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
    "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
    "proxyllm": "proxyllm",
}
# Load model config

View File

@@ -123,9 +123,24 @@ class GPT4AllAdapter(BaseLLMAdaper):
    def match(self, model_path: str):
        return "gpt4all" in model_path

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        # TODO
        pass


class ProxyllmAdapter(BaseLLMAdaper):
    """The model adapter for the local proxy"""

    def match(self, model_path: str):
        return "proxyllm" in model_path

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        return "proxyllm", None


register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
# TODO: vicuna is supported by default; other models still need to be tested and evaluated.
# The proxy adapter is just for testing; remove it later.
register_llm_model_adapters(ProxyllmAdapter)
register_llm_model_adapters(BaseLLMAdaper)
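
Context, not part of this commit: LLM_MODEL_CONFIG maps "proxyllm" to the literal string "proxyllm" because adapter selection is a substring match on the model path rather than a filesystem lookup. A minimal sketch of that dispatch, assuming the real register_llm_model_adapters / get_llm_model_adapter helpers work roughly like this (their actual implementations may differ):

llm_model_adapters = []


def register_llm_model_adapters(cls):
    # Instantiate and remember each adapter; matching happens in registration order.
    llm_model_adapters.append(cls())


def get_llm_model_adapter(model_path: str):
    # The first adapter whose match() accepts the path wins, so "proxyllm"
    # resolves to ProxyllmAdapter before the catch-all base adapter.
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")

This is also why ProxyllmAdapter.loader returns the pair ("proxyllm", None): nothing is downloaded or loaded for the proxy path.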

View File

@@ -100,7 +100,7 @@ class ModelLoader(metaclass=Singleton):
        llm_adapter = get_llm_model_adapter(self.model_path)
        model, tokenizer = llm_adapter.loader(self.model_path, kwargs)
        if load_8bit:
        if load_8bit and tokenizer:
            if num_gpus != 1:
                warnings.warn(
                    "8-bit quantization is not supported for multi-gpu inference"
@@ -110,7 +110,7 @@ class ModelLoader(metaclass=Singleton):
        if (
            self.device == "cuda" and num_gpus == 1 and not cpu_offloading
        ) or self.device == "mps":
        ) or self.device == "mps" and tokenizer:
            model.to(self.device)
        if debug:
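
The two "and tokenizer" guards exist because the proxy path never produces a real model object; a small illustration (not from this diff, and the pilot/model/adapter.py module path is assumed):

from pilot.model.adapter import ProxyllmAdapter  # module path assumed

# The proxy loader returns a marker string and no tokenizer, so ModelLoader must not
# attempt 8-bit quantization or model.to(device) on something that is not an nn.Module.
model, tokenizer = ProxyllmAdapter().loader("proxyllm", {})
assert model == "proxyllm" and tokenizer is None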

pilot/model/proxy_llm.py Normal file (56 lines added)
View File

@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import requests

from pilot.configs.config import Config
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER

CFG = Config()


def proxyllm_generate_stream(
    model, tokenizer, params, device, context_len=2048
):
    history = []

    prompt = params["prompt"]
    stop = params.get("stop", "###")

    headers = {
        "Authorization": "Bearer " + CFG.proxy_api_key
    }

    messages = prompt.split(stop)
    # Add history conversation
    for i in range(1, len(messages) - 2, 2):
        history.append(
            {"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]},
        )
        history.append(
            {"role": "system", "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1]}
        )

    # Add user query
    query = messages[-2].split(ROLE_USER + ":")[1]
    history.append(
        {"role": "user", "content": query}
    )

    payloads = {
        "model": "gpt-3.5-turbo",  # just for testing; remove this later
        "messages": history,
        "temperature": params.get("temperature"),
        "max_tokens": params.get("max_new_tokens"),
    }

    res = requests.post(CFG.proxy_server_url, headers=headers, json=payloads, stream=True)

    text = ""
    for line in res.iter_lines():
        if line:
            decoded_line = line.decode('utf-8')
            json_line = json.loads(decoded_line)
            print(json_line)
            text += json_line["choices"][0]["message"]["content"]
            yield text
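
A minimal driver sketch (not part of this commit) showing the prompt layout the generator expects: alternating ROLE_USER / ROLE_ASSISTANT turns joined by the stop token, ending with an empty assistant slot. Note that the loop above json-decodes each raw response line and reads choices[0].message.content, i.e. it assumes the proxy answers with complete chat-completion JSON objects per line rather than SSE "data:" chunks. All values below are illustrative:

from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
from pilot.model.proxy_llm import proxyllm_generate_stream

stop = "###"
# Hypothetical prompt: a preamble, one finished exchange, the new user turn,
# and a trailing empty assistant slot for the proxy to fill.
prompt = stop.join(
    [
        "A chat between a curious user and an assistant.",
        f"{ROLE_USER}: Hello",
        f"{ROLE_ASSISTANT}: Hi, how can I help?",
        f"{ROLE_USER}: Summarize what this proxy module does.",
        f"{ROLE_ASSISTANT}:",
    ]
)

params = {"prompt": prompt, "stop": stop, "temperature": 0.7, "max_new_tokens": 256}

# model/tokenizer are only the ("proxyllm", None) placeholders from ProxyllmAdapter.loader;
# the proxy path never touches them.
for partial_text in proxyllm_generate_stream("proxyllm", None, params, device="cpu"):
    print(partial_text)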

View File

@@ -84,7 +84,7 @@ class CodeGenChatAdapter(BaseChatAdpter):
        pass


class GuanacoAdapter(BaseChatAdpter):
class GuanacoChatAdapter(BaseChatAdpter):
    """Model chat adapter for Guanaco"""

    def match(self, model_path: str):
@@ -94,7 +94,20 @@ class GuanacoAdapter(BaseChatAdpter):
        # TODO
        pass


class ProxyllmChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "proxyllm" in model_path

    def get_generate_stream_func(self):
        from pilot.model.proxy_llm import proxyllm_generate_stream

        return proxyllm_generate_stream


register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
# Proxy model for testing and development; it is cheap for us for now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)
register_llm_model_chat_adapter(BaseChatAdpter)
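
To tie the pieces together, a hedged sketch (not code from this commit; the pilot/server/chat_adapter.py module path is assumed): the chat adapter simply hands the worker the proxy generator, so every generate request for "proxyllm" becomes an HTTP call to PROXY_SERVER_URL instead of local inference.

from pilot.model.proxy_llm import proxyllm_generate_stream
from pilot.server.chat_adapter import ProxyllmChatAdapter  # module path assumed

adapter = ProxyllmChatAdapter()
assert adapter.match("proxyllm")
# The lazily imported function is returned as-is; the worker calls it exactly like
# the local generate_stream functions of the other adapters.
assert adapter.get_generate_stream_func() is proxyllm_generate_stream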

View File

@@ -37,11 +37,12 @@ class ModelWorker:
        self.model, self.tokenizer = self.ml.loader(
            num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
        )
        if hasattr(self.model.config, "max_sequence_length"):
            self.context_len = self.model.config.max_sequence_length
        elif hasattr(self.model.config, "max_position_embeddings"):
            self.context_len = self.model.config.max_position_embeddings
        if not isinstance(self.model, str):
            if hasattr(self.model.config, "max_sequence_length"):
                self.context_len = self.model.config.max_sequence_length
            elif hasattr(self.model.config, "max_position_embeddings"):
                self.context_len = self.model.config.max_position_embeddings
        else:
            self.context_len = 2048

View File

@@ -434,6 +434,7 @@ def http_bot(
            """TODO: multi-model output handler; rewrite this for multiple models using the adapter mode."""
            if data["error_code"] == 0:
                print("****************:", data)
                if "vicuna" in CFG.LLM_MODEL:
                    output = data["text"][skip_echo_len:].strip()
                else: