Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-09 04:08:10 +00:00)
commit 973bcce03c

    Merge branch 'llm_proxy' into dev

    # Conflicts:
    #	pilot/server/webserver.py
@@ -85,8 +85,10 @@ class ChatGLMAdapater(BaseLLMAdaper):
 
 
 class GuanacoAdapter(BaseLLMAdaper):
     """TODO Support guanaco"""
 
     pass
 
 
 class CodeGenAdapter(BaseLLMAdaper):
     pass
 
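For context, a hedged sketch of what filling in the GuanacoAdapter stub might eventually look like, assuming the match()/loader() contract visible on the other adapters in this diff and the loader(model_path, from_pretrained_kwargs) signature used by the base class; the loading details below are illustrative assumptions, not part of the commit:

# Hypothetical sketch only; match()/loader() follow the pattern of the other
# adapters in this module, the from_pretrained details are assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer


class GuanacoAdapter(BaseLLMAdaper):
    """TODO Support guanaco"""

    def match(self, model_path: str):
        # Pick this adapter when the configured model path mentions guanaco
        return "guanaco" in model_path

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        # Load tokenizer and model with the standard Hugging Face APIs
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
        )
        return model, tokenizer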
@@ -127,9 +129,11 @@ class GPT4AllAdapter(BaseLLMAdaper):
     # TODO
     pass
 
 
 class ProxyllmAdapter(BaseLLMAdaper):
 
     """The model adapter for local proxy"""
 
     def match(self, model_path: str):
         return "proxyllm" in model_path
 
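The match() method above is what lets the loader select an adapter purely from the model name string. A minimal sketch of how such a registry lookup could work; the llm_model_adapters list and get_llm_model_adapter helper are illustrative assumptions, only the adapter classes themselves appear in the commit:

# Minimal registry sketch; the list and helper name are not from the commit.
llm_model_adapters = [ProxyllmAdapter(), ChatGLMAdapater()]


def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
    # Return the first adapter whose match() accepts the model path
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")


adapter = get_llm_model_adapter("proxyllm")
# -> ProxyllmAdapter, because its match() checks `"proxyllm" in model_path`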
@@ -109,8 +109,10 @@ class ModelLoader(metaclass=Singleton):
                 compress_module(model, self.device)
 
         if (
-            self.device == "cuda" and num_gpus == 1 and not cpu_offloading
-        ) or self.device == "mps" and tokenizer:
+            (self.device == "cuda" and num_gpus == 1 and not cpu_offloading)
+            or self.device == "mps"
+            and tokenizer
+        ):
             model.to(self.device)
 
         if debug:
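Note that the re-wrapped condition keeps the original semantics: Python's `and` binds tighter than `or`, so the expression still reads as "(cuda, single GPU, no CPU offloading) or (mps and tokenizer)". A small check with stand-in booleans illustrating the grouping (the variable names here are placeholders, not from the commit):

# Stand-in booleans just to show how Python groups the reformatted condition.
is_cuda_single_gpu = False  # self.device == "cuda" and num_gpus == 1 and not cpu_offloading
is_mps = True               # self.device == "mps"
has_tokenizer = False       # tokenizer is truthy

# "and" binds tighter than "or": A or B and C  ==  A or (B and C)
should_move = (is_cuda_single_gpu) or is_mps and has_tokenizer
assert should_move == (is_cuda_single_gpu or (is_mps and has_tokenizer))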
@@ -8,17 +8,16 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 
 CFG = Config()
 
 
-def proxyllm_generate_stream(
-    model, tokenizer, params, device, context_len=2048
-):
+def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
     history = []
 
     prompt = params["prompt"]
     stop = params.get("stop", "###")
 
     headers = {
-        "Authorization": "Bearer " + CFG.proxy_api_key
+        "Authorization": "Bearer " + CFG.proxy_api_key,
+        "Token": CFG.proxy_api_key,
     }
 
     messages = prompt.split(stop)
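To make the prompt.split(stop) step concrete, here is a toy example. It assumes ROLE_USER and ROLE_ASSISTANT are the plain strings "USER" and "ASSISTANT" and that turns in the rendered prompt are separated by the stop token; the real prompt is produced by the conversation template in pilot.conversation, so the exact format is an assumption:

# Toy illustration of the splitting step; the prompt format is assumed.
ROLE_USER, ROLE_ASSISTANT = "USER", "ASSISTANT"
stop = "###"

prompt = "A chat between a user and an assistant.###USER: hi###ASSISTANT: hello###USER: what is DB-GPT?###"
messages = prompt.split(stop)
# ['A chat between a user and an assistant.', 'USER: hi', 'ASSISTANT: hello',
#  'USER: what is DB-GPT?', '']
print(messages)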
@@ -29,14 +28,15 @@ def proxyllm_generate_stream(
             {"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]},
         )
         history.append(
-            {"role": "system", "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1]}
+            {
+                "role": "system",
+                "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
+            }
         )
 
     # Add user query
     query = messages[-2].split(ROLE_USER + ":")[1]
-    history.append(
-        {"role": "user", "content": query}
-    )
+    history.append({"role": "user", "content": query})
     payloads = {
         "model": "gpt-3.5-turbo",  # just for test, remove this later
         "messages": history,
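The loop that surrounds this hunk walks the split segments pairwise to build an OpenAI-style messages list. A hedged reconstruction of the whole conversion, continuing the toy prompt from above; only the history.append(...) calls are verbatim from the diff, the iteration bounds are assumptions:

# Hedged reconstruction of the surrounding loop; iteration bounds are assumed.
ROLE_USER, ROLE_ASSISTANT = "USER", "ASSISTANT"
messages = ["system preamble", "USER: hi", "ASSISTANT: hello", "USER: what is DB-GPT?", ""]

history = []
for i in range(1, len(messages) - 2, 2):
    history.append(
        {"role": "user", "content": messages[i].split(ROLE_USER + ":")[1]},
    )
    history.append(
        {
            "role": "system",
            "content": messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
        }
    )

# The latest user turn becomes the final message, as in the diff
query = messages[-2].split(ROLE_USER + ":")[1]
history.append({"role": "user", "content": query})
# history: [{'role': 'user', 'content': ' hi'},
#           {'role': 'system', 'content': ' hello'},
#           {'role': 'user', 'content': ' what is DB-GPT?'}]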
@@ -44,12 +44,18 @@ def proxyllm_generate_stream(
         "max_tokens": params.get("max_new_tokens"),
     }
 
-    res = requests.post(CFG.proxy_server_url, headers=headers, json=payloads, stream=True)
+    print(payloads)
+    print(headers)
+    res = requests.post(
+        CFG.proxy_server_url, headers=headers, json=payloads, stream=True
+    )
 
     text = ""
+    print("====================================res================")
+    print(res)
     for line in res.iter_lines():
         if line:
-            decoded_line = line.decode('utf-8')
+            decoded_line = line.decode("utf-8")
             json_line = json.loads(decoded_line)
             print(json_line)
             text += json_line["choices"][0]["message"]["content"]
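The response handling above assumes the proxy returns one complete JSON object per line, each shaped like a non-streaming chat-completions body (choices[0].message.content), rather than SSE "data: ..." chunks. A small self-contained sketch of parsing one such line, with a sample payload invented for illustration:

import json

# One line of the proxy response, shaped as the code above expects.
line = b'{"choices": [{"message": {"role": "assistant", "content": "Hello"}}]}'

decoded_line = line.decode("utf-8")
json_line = json.loads(decoded_line)
text = json_line["choices"][0]["message"]["content"]
print(text)  # -> Hello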
@@ -85,7 +85,7 @@ class CodeGenChatAdapter(BaseChatAdpter):
 
 
 class GuanacoChatAdapter(BaseChatAdpter):
-    """Model chat adapter for Guanaco """
+    """Model chat adapter for Guanaco"""
 
     def match(self, model_path: str):
         return "guanaco" in model_path
@@ -101,6 +101,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
 
     def get_generate_stream_func(self):
         from pilot.model.proxy_llm import proxyllm_generate_stream
 
         return proxyllm_generate_stream
 
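For context, a sketch of how the chat-adapter side could be wired up, mirroring the model-adapter lookup sketched earlier: match() selects the adapter from the model name and get_generate_stream_func() hands back the streaming callable. The registry list and get_llm_chat_adapter helper are illustrative assumptions; only the adapter classes and their methods come from the commit:

# Illustrative wiring only; the registry list and helper name are assumptions.
llm_model_chat_adapters = [ProxyllmChatAdapter(), GuanacoChatAdapter()]


def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
    # Return the first chat adapter whose match() accepts the model path
    for adapter in llm_model_chat_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model chat adapter for {model_path}")


chat_adapter = get_llm_chat_adapter("proxyllm")
generate_stream_func = chat_adapter.get_generate_stream_func()
# generate_stream_func is now proxyllm_generate_stream, which the server can
# invoke with the rendered prompt parameters.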