mirror of https://github.com/csunny/DB-GPT.git

commit 4e5ce4d98b

Merge remote-tracking branch 'origin/feature-xuyuan-openai-proxy' into ty_test

# Conflicts:
#	pilot/model/llm_out/proxy_llm.py
@@ -7,6 +7,10 @@
 ## For example, to disable coding related features, uncomment the next line
 # DISABLED_COMMAND_CATEGORIES=
 
+#*******************************************************************#
+#** Webserver Port **#
+#*******************************************************************#
+WEB_SERVER_PORT=7860
 
 #*******************************************************************#
 #*** LLM PROVIDER ***#
@@ -17,6 +21,7 @@
 #*******************************************************************#
 #** LLM MODELS **#
 #*******************************************************************#
+# LLM_MODEL, see /pilot/configs/model_config.LLM_MODEL_CONFIG
 LLM_MODEL=vicuna-13b
 MODEL_SERVER=http://127.0.0.1:8000
 LIMIT_MODEL_CONCURRENCY=5
@@ -98,15 +103,20 @@ VECTOR_STORE_TYPE=Chroma
 #MILVUS_SECURE=
 
 
+#*******************************************************************#
+#** WebServer Language Support **#
+#*******************************************************************#
 LANGUAGE=en
 #LANGUAGE=zh
 
 
 #*******************************************************************#
-# ** PROXY_SERVER
+# ** PROXY_SERVER (openai interface | chatGPT proxy service), use chatGPT as your LLM.
+# ** if your server can visit openai, please set PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions
+# ** else if you have a chatgpt proxy server, you can set PROXY_SERVER_URL={your-proxy-serverip:port/xxx}
 #*******************************************************************#
-PROXY_API_KEY=
-PROXY_SERVER_URL=http://127.0.0.1:3000/proxy_address
+PROXY_API_KEY={your-openai-sk}
+PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions
 
 
 #*******************************************************************#
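For context, a minimal sketch of how the two proxy settings above are typically consumed, assuming an OpenAI-compatible chat-completions endpoint. The request shape is the public API format, not code from this repository, and the default values below are placeholders.

import os

import requests

# Read the values configured from the template above; defaults are placeholders.
proxy_api_key = os.getenv("PROXY_API_KEY", "")
proxy_server_url = os.getenv(
    "PROXY_SERVER_URL", "https://api.openai.com/v1/chat/completions"
)

headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {proxy_api_key}",  # the sk-... key goes here
}
payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Hello"}],
}

# Non-streaming call: the whole answer comes back in a single JSON body.
resp = requests.post(proxy_server_url, headers=headers, json=payload, timeout=60)
print(resp.json()["choices"][0]["message"]["content"])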
@@ -17,8 +17,9 @@ class Config(metaclass=Singleton):
     def __init__(self) -> None:
         """Initialize the Config class"""
 
-        # Gradio language version: en, cn
+        # Gradio language version: en, zh
         self.LANGUAGE = os.getenv("LANGUAGE", "en")
+        self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860))
 
         self.debug_mode = False
         self.skip_reprompt = False
@@ -62,10 +62,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
     history.append(last_user_input)
 
     payloads = {
-        "model": "gpt-3.5-turbo", # just for test_py, remove this later
+        "model": "gpt-3.5-turbo", # just for test, remove this later
         "messages": history,
         "temperature": params.get("temperature"),
         "max_tokens": params.get("max_new_tokens"),
+        "stream": True
     }
 
     res = requests.post(
@@ -75,14 +76,32 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
     text = ""
     for line in res.iter_lines():
         if line:
-            decoded_line = line.decode("utf-8")
-            try:
-                json_line = json.loads(decoded_line)
-                print(json_line)
-                text += json_line["choices"][0]["message"]["content"]
-                yield text
-            except Exception as e:
-                text += decoded_line
-                yield json.loads(text)["choices"][0]["message"]["content"]
+            json_data = line.split(b': ', 1)[1]
+            decoded_line = json_data.decode("utf-8")
+            if decoded_line.lower() != '[DONE]'.lower():
+                obj = json.loads(json_data)
+                if obj['choices'][0]['delta'].get('content') is not None:
+                    content = obj['choices'][0]['delta']['content']
+                    text += content
+            yield text
 
 
+    # native result.
+    # payloads = {
+    #     "model": "gpt-3.5-turbo", # just for test, remove this later
+    #     "messages": history,
+    #     "temperature": params.get("temperature"),
+    #     "max_tokens": params.get("max_new_tokens"),
+    # }
+    #
+    # res = requests.post(
+    #     CFG.proxy_server_url, headers=headers, json=payloads, stream=True
+    # )
+    #
+    # text = ""
+    # line = res.content
+    # if line:
+    #     decoded_line = line.decode("utf-8")
+    #     json_line = json.loads(decoded_line)
+    #     print(json_line)
+    #     text += json_line["choices"][0]["message"]["content"]
+    #     yield text
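A self-contained sketch of the parsing strategy introduced above, assuming the standard OpenAI chat-completions streaming format; the helper and its sample data are illustrative, not code from proxy_llm.py. Each non-empty line is an SSE record of the form "data: {json}", the stream ends with a "data: [DONE]" sentinel, and only delta chunks that actually carry "content" extend the accumulated text (the first chunk typically carries just the assistant role).

import json

def accumulate_stream(lines):
    """lines: an iterable of raw SSE bytes, e.g. response.iter_lines()."""
    text = ""
    for line in lines:
        if not line or not line.startswith(b"data:"):
            continue  # skip keep-alive blanks and non-data records
        payload = line.split(b": ", 1)[1]
        decoded = payload.decode("utf-8")
        if decoded.strip() == "[DONE]":
            break  # end-of-stream sentinel
        chunk = json.loads(decoded)
        delta = chunk["choices"][0]["delta"]
        if delta.get("content") is not None:  # first chunk carries only the role
            text += delta["content"]
        yield text  # yield the running transcript, as the generator above does

# Example with canned SSE lines:
sample = [
    b'data: {"choices": [{"delta": {"role": "assistant"}}]}',
    b'data: {"choices": [{"delta": {"content": "Hel"}}]}',
    b'data: {"choices": [{"delta": {"content": "lo"}}]}',
    b'data: [DONE]',
]
print(list(accumulate_stream(sample))[-1])  # prints "Hello"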
@@ -84,6 +84,11 @@ class ModelWorker:
         return get_embeddings(self.model, self.tokenizer, prompt)
 
 
+model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
+worker = ModelWorker(
+    model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1
+)
+
 app = FastAPI()
 
 
@@ -157,11 +162,4 @@ def embeddings(prompt_request: EmbeddingRequest):
 
 
 if __name__ == "__main__":
-    model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
-    print(model_path, DEVICE)
-
-    worker = ModelWorker(
-        model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1
-    )
-
     uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")
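Taken together, the two hunks above move model loading and ModelWorker construction out of the __main__ block and up to module scope. A plausible reading, not stated in the commit itself, is that this ensures the FastAPI app object always has an initialized worker regardless of how the module is served, rather than only when the file is run directly. A schematic sketch of the pattern, with hypothetical names and paths:

from fastapi import FastAPI

class DummyWorker:
    """Stands in for ModelWorker; loading a real model would happen here."""
    def __init__(self, model_path: str):
        self.model_path = model_path

# Module scope: runs on import, so every way of serving `app` sees the worker.
worker = DummyWorker(model_path="/models/vicuna-13b")  # hypothetical path
app = FastAPI()

@app.get("/status")
def status():
    return {"model": worker.model_path}

if __name__ == "__main__":
    # The __main__ block now only starts the server; it no longer owns the worker.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)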
@@ -658,7 +658,7 @@ def signal_handler(sig, frame):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
-    parser.add_argument("--port", type=int)
+    parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
     parser.add_argument("--concurrency-count", type=int, default=10)
     parser.add_argument(
         "--model-list-mode", type=str, default="once", choices=["once", "reload"]