diff --git a/docs/modules/llms.md b/docs/modules/llms.md index dd7cbc064..474f7bb9a 100644 --- a/docs/modules/llms.md +++ b/docs/modules/llms.md @@ -128,9 +128,17 @@ PROXY_SERVER_URL={your-openai-proxy-server/v1/chat/completions} ``` ### 2. Bard Proxy -- If your environment deploying DB-GPT has access to https://bard.google.com/ (F12-> application-> __Secure-1PSID), then modify the .env configuration file as below will work. +- If your environment deploying DB-GPT has access to Bard (F12-> application-> __Secure-1PSID), then modify the .env configuration file as below will work. ``` LLM_MODEL=bard_proxyllm MODEL_SERVER=127.0.0.1:8000 BARD_PROXY_API_KEY={your-bard-key} +# PROXY_SERVER_URL={your-bard-proxy-server/v1/chat/completions} +``` + +- If you want to use your own bard proxy server, such as Bard-Proxy, you can easily deploy DB-GPT on your PC with the configuration below. +``` +LLM_MODEL=bard_proxyllm +MODEL_SERVER=127.0.0.1:8000 +PROXY_SERVER_URL={your-bard-proxy-server/v1/chat/completions} ``` \ No newline at end of file diff --git a/pilot/model/proxy/llms/bard.py b/pilot/model/proxy/llms/bard.py index badb0e912..73f959512 100644 --- a/pilot/model/proxy/llms/bard.py +++ b/pilot/model/proxy/llms/bard.py @@ -1,4 +1,5 @@ import bardapi +import requests from typing import List from pilot.configs.config import Config from pilot.scene.base_message import ModelMessage, ModelMessageRoleType @@ -7,8 +8,6 @@ CFG = Config() def bard_generate_stream(model, tokenizer, params, device, context_len=2048): - token = CFG.bard_proxy_api_key - history = [] messages: List[ModelMessage] = params["messages"] for message in messages: @@ -35,8 +34,21 @@ def bard_generate_stream(model, tokenizer, params, device, context_len=2048): for msg in history: if msg.get("content"): msgs.append(msg["content"]) - response = bardapi.core.Bard(token).get_answer("\n".join(msgs)) - if response is not None and response.get("content") is not None: - yield str(response["content"]) + + if CFG.proxy_server_url is not None: + 
headers = {"Content-Type": "application/json"} + payloads = {"input": "\n".join(msgs)} + response = requests.post( + CFG.proxy_server_url, headers=headers, json=payloads, stream=False + ) + if response.ok: + yield response.text + else: + yield f"bard proxy url request failed! response = {str(response)}" else: - yield f"bard response error: {str(response)}" + response = bardapi.core.Bard(CFG.bard_proxy_api_key).get_answer("\n".join(msgs)) + + if response is not None and response.get("content") is not None: + yield str(response["content"]) + else: + yield f"bard response error: {str(response)}"