feat: support bard proxy server

support bard proxy server

Close #384
This commit is contained in:
xuyuan23 2023-07-31 17:49:10 +08:00
parent f57e6c8ba1
commit 1ff4cfdc42

View File

@ -1,4 +1,5 @@
import bardapi
import requests
from typing import List
from pilot.configs.config import Config
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
@ -7,8 +8,6 @@ CFG = Config()
def bard_generate_stream(model, tokenizer, params, device, context_len=2048):
token = CFG.bard_proxy_api_key
history = []
messages: List[ModelMessage] = params["messages"]
for message in messages:
@ -35,8 +34,21 @@ def bard_generate_stream(model, tokenizer, params, device, context_len=2048):
for msg in history:
if msg.get("content"):
msgs.append(msg["content"])
response = bardapi.core.Bard(token).get_answer("\n".join(msgs))
if response is not None and response.get("content") is not None:
yield str(response["content"])
if CFG.proxy_server_url is not None:
headers = {"Content-Type": "application/json"}
payloads = {"input": "\n".join(msgs)}
response = requests.post(
CFG.proxy_server_url, headers=headers, json=payloads, stream=False
)
if response.ok is True:
yield response.text
else:
yield f"bard proxy url request failed!, response = {str(response)}"
else:
yield f"bard response error: {str(response)}"
response = bardapi.core.Bard(CFG.bard_proxy_api_key).get_answer("\n".join(msgs))
if response is not None and response.get("content") is not None:
yield str(response["content"])
else:
yield f"bard response error: {str(response)}"