Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-01 16:18:27 +00:00)
feat: add Bard LLM proxy

Extend the proxy framework and add a proxy access method for the Bard LLM. Closes #369
parent dbf8b20c0b
commit d95726c04a
@@ -141,6 +141,7 @@ class Config(metaclass=Singleton):
         ### LLM Model Service Configuration
         self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
+        self.PROXY_MODEL = os.getenv("PROXY_MODEL", "chatgpt")
         self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
         self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
         self.MODEL_PORT = os.getenv("MODEL_PORT", 8000)
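With this change the proxy backend is selected from configuration rather than hard-coded. A minimal sketch of switching the service to the new Bard proxy (not part of this commit; PROXY_API_KEY is an assumed name for the variable behind CFG.proxy_api_key and may differ in the actual .env template):

import os

# Select the Bard backend before the Config singleton is first created,
# so Config picks the values up from the environment.
os.environ["PROXY_MODEL"] = "bard"            # key into the proxy generator mapping
os.environ["PROXY_API_KEY"] = "<bard-token>"  # forwarded as the Bard session token (assumed variable name)

from pilot.configs.config import Config

CFG = Config()
assert CFG.PROXY_MODEL == "bard"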
@@ -1,74 +1,32 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-import time
-
-import json
-import requests
-from typing import List
 from pilot.configs.config import Config
-from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
-from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+
+from pilot.model.proxy.proxy_llms.chatgpt import chatgpt_generate_stream
+from pilot.model.proxy.proxy_llms.bard import bard_generate_stream
+from pilot.model.proxy.proxy_llms.claude import claude_generate_stream
+from pilot.model.proxy.proxy_llms.wenxin import wenxin_generate_stream
+from pilot.model.proxy.proxy_llms.tongyi import tongyi_generate_stream
+from pilot.model.proxy.proxy_llms.gpt4 import gpt4_generate_stream
 
 CFG = Config()
 
 
 def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
-    history = []
-
-    prompt = params["prompt"]
-    stop = params.get("stop", "###")
-
-    headers = {
-        "Authorization": "Bearer " + CFG.proxy_api_key,
-        "Token": CFG.proxy_api_key,
-    }
-
-    messages: List[ModelMessage] = params["messages"]
-    # Add history conversation
-    for message in messages:
-        if message.role == ModelMessageRoleType.HUMAN:
-            history.append({"role": "user", "content": message.content})
-        elif message.role == ModelMessageRoleType.SYSTEM:
-            history.append({"role": "system", "content": message.content})
-        elif message.role == ModelMessageRoleType.AI:
-            history.append({"role": "assistant", "content": message.content})
-        else:
-            pass
-
-    # Move the last user's information to the end
-    temp_his = history[::-1]
-    last_user_input = None
-    for m in temp_his:
-        if m["role"] == "user":
-            last_user_input = m
-            break
-    if last_user_input:
-        history.remove(last_user_input)
-        history.append(last_user_input)
-
-    payloads = {
-        "model": "gpt-3.5-turbo",  # just for test, remove this later
-        "messages": history,
-        "temperature": params.get("temperature"),
-        "max_tokens": params.get("max_new_tokens"),
-        "stream": True,
-    }
-
-    res = requests.post(
-        CFG.proxy_server_url, headers=headers, json=payloads, stream=True
-    )
-
-    text = ""
-    for line in res.iter_lines():
-        if line:
-            if not line.startswith(b"data: "):
-                error_message = line.decode("utf-8")
-                yield error_message
-            else:
-                json_data = line.split(b": ", 1)[1]
-                decoded_line = json_data.decode("utf-8")
-                if decoded_line.lower() != "[DONE]".lower():
-                    obj = json.loads(json_data)
-                    if obj["choices"][0]["delta"].get("content") is not None:
-                        content = obj["choices"][0]["delta"]["content"]
-                        text += content
-                        yield text
+    generator_mapping = {
+        "chatgpt": chatgpt_generate_stream,
+        "bard": bard_generate_stream,
+        "claude": claude_generate_stream,
+        "gpt4": gpt4_generate_stream,
+        "wenxin": wenxin_generate_stream,
+        "tongyi": tongyi_generate_stream,
+    }
+
+    default_error_message = f"{CFG.PROXY_MODEL} LLM is not supported"
+    generator_function = generator_mapping.get(
+        CFG.PROXY_MODEL, lambda: default_error_message
+    )
+
+    yield from generator_function(model, tokenizer, params, device, context_len)
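All of the proxy stream functions share the same calling convention: they receive a params dict and yield text incrementally. A minimal sketch of that dict, with field names taken from the code above (the ModelMessage keyword constructor is an assumption about pilot.scene.base_message):

from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

# Keys read by the generators above: "messages" drives the history sent to the
# provider, "temperature" and "max_new_tokens" are passed through to the request.
params = {
    "messages": [
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are a helpful assistant."),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
    ],
    "temperature": 0.7,
    "max_new_tokens": 512,
}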
pilot/model/proxy/proxy_llms/__init__.py (new file)

"""
Privately deployed large models have several limitations: high deployment costs and poor performance.
In scenarios where data privacy requirements are relatively low, connecting to commercial large models
enables rapid, efficient, and high-quality product delivery.
"""
pilot/model/proxy/proxy_llms/bard.py (new file)

import bardapi
from typing import List
from pilot.configs.config import Config
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

CFG = Config()


def bard_generate_stream(model, tokenizer, params, device, context_len=2048):
    token = CFG.proxy_api_key

    history = []
    messages: List[ModelMessage] = params["messages"]
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.SYSTEM:
            history.append({"role": "system", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})
        else:
            pass

    # Move the last user message to the end of the history
    temp_his = history[::-1]
    last_user_input = None
    for m in temp_his:
        if m["role"] == "user":
            last_user_input = m
            break
    if last_user_input:
        history.remove(last_user_input)
        history.append(last_user_input)

    # Bard's get_answer expects a plain prompt string, not the message dict
    response = bardapi.core.Bard(token).get_answer(last_user_input["content"])
    if response is not None and response.get("content") is not None:
        yield str(response["content"])
    else:
        yield f"bard response error: {str(response)}"


if __name__ == "__main__":
    # Simple smoke test: stream a single-turn conversation through the Bard proxy
    test_messages = [ModelMessage(role=ModelMessageRoleType.HUMAN, content="hi")]
    for output in bard_generate_stream(
        "bard_proxy_llm", None, {"messages": test_messages}, None, 2048
    ):
        print(output)
pilot/model/proxy/proxy_llms/chatgpt.py (new file)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import requests
from typing import List
from pilot.configs.config import Config
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

CFG = Config()


def chatgpt_generate_stream(model, tokenizer, params, device, context_len=2048):
    history = []

    headers = {
        "Authorization": "Bearer " + CFG.proxy_api_key,
        "Token": CFG.proxy_api_key,
    }

    messages: List[ModelMessage] = params["messages"]
    # Add history conversation
    for message in messages:
        if message.role == ModelMessageRoleType.HUMAN:
            history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.SYSTEM:
            history.append({"role": "system", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})
        else:
            pass

    # Move the last user's information to the end
    temp_his = history[::-1]
    last_user_input = None
    for m in temp_his:
        if m["role"] == "user":
            last_user_input = m
            break
    if last_user_input:
        history.remove(last_user_input)
        history.append(last_user_input)

    payloads = {
        "model": "gpt-3.5-turbo",  # just for test, remove this later
        "messages": history,
        "temperature": params.get("temperature"),
        "max_tokens": params.get("max_new_tokens"),
        "stream": True,
    }

    res = requests.post(
        CFG.proxy_server_url, headers=headers, json=payloads, stream=True
    )

    text = ""
    for line in res.iter_lines():
        if line:
            if not line.startswith(b"data: "):
                error_message = line.decode("utf-8")
                yield error_message
            else:
                json_data = line.split(b": ", 1)[1]
                decoded_line = json_data.decode("utf-8")
                if decoded_line.lower() != "[DONE]".lower():
                    obj = json.loads(json_data)
                    if obj["choices"][0]["delta"].get("content") is not None:
                        content = obj["choices"][0]["delta"]["content"]
                        text += content
                        yield text
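For context, the streaming loop above consumes OpenAI-style server-sent events, one "data: ..." line per token delta, terminated by "data: [DONE]". A small self-contained sketch of that parsing logic against hard-coded sample lines (illustrative only, not part of this commit):

import json

sample_lines = [
    b'data: {"choices": [{"delta": {"content": "Hel"}}]}',
    b'data: {"choices": [{"delta": {"content": "lo"}}]}',
    b"data: [DONE]",
]

text = ""
for line in sample_lines:
    payload = line.split(b": ", 1)[1]        # strip the "data: " prefix
    if payload.decode("utf-8").lower() == "[done]":
        break
    delta = json.loads(payload)["choices"][0]["delta"]
    if delta.get("content") is not None:
        text += delta["content"]

print(text)  # -> Hello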
pilot/model/proxy/proxy_llms/claude.py (new file)

from pilot.configs.config import Config

CFG = Config()


def claude_generate_stream(model, tokenizer, params, device, context_len=2048):
    yield "claude LLM is not supported!"
pilot/model/proxy/proxy_llms/gpt4.py (new file)

from pilot.configs.config import Config

CFG = Config()


def gpt4_generate_stream(model, tokenizer, params, device, context_len=2048):
    yield "gpt4 LLM is not supported!"
pilot/model/proxy/proxy_llms/tongyi.py (new file)

from pilot.configs.config import Config

CFG = Config()


def tongyi_generate_stream(model, tokenizer, params, device, context_len=2048):
    yield "tongyi LLM is not supported!"
pilot/model/proxy/proxy_llms/wenxin.py (new file)

from pilot.configs.config import Config

CFG = Config()


def wenxin_generate_stream(model, tokenizer, params, device, context_len=2048):
    yield "wenxin LLM is not supported!"
pilot/model/proxy/security/data_mask/__init__.py (new file)

"""
Data masking: transform private, sensitive data into masked data, based on the sensitive data recognition tool.
"""
pilot/model/proxy/security/data_mask/masking.py (new file)

"""
Mask sensitive data before it is uploaded to the LLM inference service.
"""
pilot/model/proxy/security/data_mask/recovery.py (new file)

"""
Recover the original data after LLM inference.
"""
pilot/model/proxy/security/sens_data_recognition.py (new file)

"""
A tool to discover sensitive data.
"""
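The four security stubs above only carry docstrings; no masking logic ships in this commit. Purely as a hypothetical illustration of the mask-then-recover round trip those docstrings describe (every name below is invented for this sketch):

import re

# Recognize one kind of sensitive data (emails), replace it with placeholders
# before the prompt is sent to a proxy LLM, and restore the originals afterwards.
EMAIL_PATTERN = re.compile(r"[\w.+-]+@[\w-]+\.[\w.]+")


def mask(text: str):
    mapping = {}

    def _replace(match):
        key = f"<MASK_{len(mapping)}>"
        mapping[key] = match.group(0)
        return key

    return EMAIL_PATTERN.sub(_replace, text), mapping


def recover(text: str, mapping: dict) -> str:
    for key, original in mapping.items():
        text = text.replace(key, original)
    return text


masked, mapping = mask("Contact alice@example.com for access.")
print(masked)                    # Contact <MASK_0> for access.
print(recover(masked, mapping))  # Contact alice@example.com for access.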