Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-04 02:25:08 +00:00)
feat: use proxymodel params for all proxy model
@@ -51,8 +51,47 @@ class Config(metaclass=Singleton):
         if self.bard_proxy_api_key:
             os.environ["bard_proxyllm_proxy_api_key"] = self.bard_proxy_api_key
 
-        self.tongyi_api_key = os.getenv("TONGYI_PROXY_API_KEY")
+        # tongyi
+        self.tongyi_proxy_api_key = os.getenv("TONGYI_PROXY_API_KEY")
+        if self.tongyi_proxy_api_key:
+            os.environ["tongyi_proxyllm_proxy_api_key"] = self.tongyi_proxy_api_key
+
+        # zhipu
+        self.zhipu_proxy_api_key = os.getenv("ZHIPU_PROXY_API_KEY")
+        if self.zhipu_proxy_api_key:
+            os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key
+            os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv("ZHIPU_MODEL_VERSION")
+
+        # wenxin
+        self.wenxin_proxy_api_key = os.getenv("WEN_XIN_API_KEY")
+        self.wenxin_proxy_api_secret = os.getenv("WEN_XIN_SECRET_KEY")
+        self.wenxin_model_version = os.getenv("WEN_XIN_MODEL_VERSION")
+        if self.wenxin_proxy_api_key and self.wenxin_proxy_api_secret:
+            os.environ["wenxin_proxyllm_proxy_api_key"] = self.wenxin_proxy_api_key
+            os.environ["wenxin_proxyllm_proxy_api_secret"] = self.wenxin_proxy_api_secret
+            os.environ["wenxin_proxyllm_proxyllm_backend"] = self.wenxin_model_version
+
+        # xunfei spark
+        self.spark_api_version = os.getenv("XUNFEI_SPARK_API_VERSION")
+        self.spark_proxy_api_key = os.getenv("XUNFEI_SPARK_API_KEY")
+        self.spark_proxy_api_secret = os.getenv("XUNFEI_SPARK_API_SECRET")
+        self.spark_proxy_api_appid = os.getenv("XUNFEI_SPARK_APPID")
+        if self.spark_proxy_api_key and self.spark_proxy_api_secret:
+            os.environ["spark_proxyllm_proxy_api_key"] = self.spark_proxy_api_key
+            os.environ["spark_proxyllm_proxy_api_secret"] = self.spark_proxy_api_secret
+            os.environ["spark_proxyllm_proxyllm_backend"] = self.spark_api_version
+            os.environ["spark_proxyllm_proxy_app_id"] = self.spark_proxy_api_appid
+
+        # baichuan proxy
+        self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
+        self.bc_proxy_api_secret = os.getenv("BAICHUAN_PROXY_API_SECRET")
+        self.bc_model_version = os.getenv("BAICHUN_MODEL_NAME")
+        if self.bc_proxy_api_key and self.bc_proxy_api_secret:
+            os.environ["bc_proxyllm_proxy_api_key"] = self.bc_proxy_api_key
+            os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret
+            os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version
+
         self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
 
         self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
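Note: the block above maps the vendor-specific environment variables (for example TONGYI_PROXY_API_KEY) onto the lowercase "{model_name}_proxyllm_*" convention that the proxy-model parameter loader is expected to read back. A minimal sketch of that convention, assuming a loader that simply strips the prefix; the helper name and logic here are illustrative, not DB-GPT's actual implementation:

import os

def read_proxy_params(model_name: str) -> dict:
    """Collect os.environ entries written under f"{model_name}_" (illustrative)."""
    prefix = f"{model_name}_"
    return {
        key[len(prefix):]: value
        for key, value in os.environ.items()
        if key.startswith(prefix)
    }

os.environ["tongyi_proxyllm_proxy_api_key"] = "sk-placeholder"  # hypothetical value
print(read_proxy_params("tongyi_proxyllm"))  # -> {'proxy_api_key': 'sk-placeholder'}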
@@ -285,6 +285,19 @@ class ProxyModelParameters(BaseModelParameters):
     proxy_api_key: str = field(
         metadata={"tags": "privacy", "help": "The api key of current proxy LLM"},
     )
+
+    proxy_app_id: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Appid for visitor proxy"
+        },
+    )
+
+    proxy_api_secret: Optional[str] = field(
+        default=None,
+        metadata={"tags": "privacy", "help": "The api secret of current proxy LLM"},
+    )
+
     http_proxy: Optional[str] = field(
         default=os.environ.get("http_proxy") or os.environ.get("https_proxy"),
         metadata={"help": "The http or https proxy to use openai"},
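The new proxy_app_id and proxy_api_secret fields follow the same dataclasses.field pattern as the existing proxy_api_key. A self-contained sketch (a stand-in class, not the real ProxyModelParameters) showing how those field definitions behave:

from dataclasses import dataclass, field, fields
from typing import Optional

@dataclass
class ExampleProxyParams:  # stand-in for illustration only
    proxy_api_key: str = field(
        metadata={"tags": "privacy", "help": "The api key of current proxy LLM"},
    )
    proxy_app_id: Optional[str] = field(
        default=None, metadata={"help": "Appid for visitor proxy"}
    )
    proxy_api_secret: Optional[str] = field(
        default=None,
        metadata={"tags": "privacy", "help": "The api secret of current proxy LLM"},
    )

params = ExampleProxyParams(proxy_api_key="sk-placeholder")
for f in fields(params):
    print(f.name, f.metadata.get("help"))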
@@ -27,9 +27,9 @@ def baichuan_generate_stream(
     model_params = model.get_params()
     url = "https://api.baichuan-ai.com/v1/stream/chat"
 
-    model_name = os.getenv("BAICHUN_MODEL_NAME") or BAICHUAN_DEFAULT_MODEL
-    proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
-    proxy_api_secret = os.getenv("BAICHUAN_PROXY_API_SECRET")
+    model_name = model_params.proxyllm_backend or BAICHUAN_DEFAULT_MODEL
+    proxy_api_key = model_params.proxy_api_key
+    proxy_api_secret = model_params.proxy_api_secret
 
 
     history = []
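Reading credentials from model_params instead of os.getenv also makes the generator easier to drive directly, e.g. from a test. A hedged sketch with a hypothetical stub standing in for ProxyModel; only the attribute names come from the diff above, the stub itself is not part of DB-GPT:

from types import SimpleNamespace

fake_params = SimpleNamespace(
    proxyllm_backend=None,              # falls back to BAICHUAN_DEFAULT_MODEL
    proxy_api_key="sk-placeholder",
    proxy_api_secret="secret-placeholder",
)
fake_model = SimpleNamespace(get_params=lambda: fake_params)

print(fake_model.get_params().proxy_api_key)  # "sk-placeholder"
# baichuan_generate_stream(fake_model, tokenizer=None, params={}, device="cpu")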
@@ -19,10 +19,10 @@ def spark_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
 ):
     model_params = model.get_params()
-    proxy_api_version = os.getenv("XUNFEI_SPARK_API_VERSION") or SPARK_DEFAULT_API_VERSION
-    proxy_api_key = os.getenv("XUNFEI_SPARK_API_KEY")
-    proxy_api_secret = os.getenv("XUNFEI_SPARK_API_SECRET")
-    proxy_app_id = os.getenv("XUNFEI_SPARK_APPID")
+    proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION
+    proxy_api_key = model_params.proxy_api_key
+    proxy_api_secret = model_params.proxy_api_secret
+    proxy_app_id = model_params.proxy_app_id
 
     if proxy_api_version == SPARK_DEFAULT_API_VERSION:
         url = "ws://spark-api.xf-yun.com/v2.1/chat"
@@ -15,8 +15,8 @@ def tongyi_generate_stream(
     model_params = model.get_params()
     print(f"Model: {model}, model_params: {model_params}")
 
-    # proxy_api_key = model_params.proxy_api_key # // TODO Set this according env
-    dashscope.api_key = os.getenv("TONGYI_PROXY_API_KEY")
+    proxy_api_key = model_params.proxy_api_key
+    dashscope.api_key = proxy_api_key
 
 
     proxyllm_backend = model_params.proxyllm_backend
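With dashscope.api_key populated from the proxy params, a streaming call through the DashScope SDK typically looks like the sketch below; the model name, prompt, and result handling are placeholders, and the exact call used by DB-GPT is not shown in this hunk:

import dashscope

dashscope.api_key = "sk-placeholder"  # in DB-GPT this comes from model_params.proxy_api_key
responses = dashscope.Generation.call(
    model="qwen-turbo",                # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    result_format="message",
)
for resp in responses:
    print(resp)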
@@ -29,13 +29,13 @@ def wenxin_generate_stream(
     }
 
     model_params = model.get_params()
-    model_name = os.getenv("WEN_XIN_MODEL_VERSION")
+    model_name = model_params.proxyllm_backend
     model_version = MODEL_VERSION.get(model_name)
     if not model_version:
         yield f"Unsupport model version {model_name}"
 
-    proxy_api_key = os.getenv("WEN_XIN_API_KEY")
-    proxy_api_secret = os.getenv("WEN_XIN_SECRET_KEY")
+    proxy_api_key = model_params.proxy_api_key
+    proxy_api_secret = model_params.proxy_api_secret
     access_token = _build_access_token(proxy_api_key, proxy_api_secret)
 
     headers = {
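_build_access_token exchanges the wenxin key/secret pair for an access token. The helper itself is outside this hunk; the sketch below shows one common way such a token is obtained from Baidu's OAuth endpoint and is an assumption about, not a copy of, the project's implementation:

import requests

def build_access_token_sketch(api_key: str, api_secret: str) -> str:
    # Exchange the key/secret pair for a short-lived access token (assumed flow).
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {
        "grant_type": "client_credentials",
        "client_id": api_key,
        "client_secret": api_secret,
    }
    resp = requests.get(url, params=params, timeout=10)
    resp.raise_for_status()
    return resp.json().get("access_token")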
@@ -14,6 +14,7 @@ def zhipu_generate_stream(
     model_params = model.get_params()
     print(f"Model: {model}, model_params: {model_params}")
 
+    # TODO proxy model use unified config?
     proxy_api_key = model_params.proxy_api_key
     proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend
 
@@ -1,9 +1,12 @@
 from typing import Any
 import logging
 from pilot.vector_store.base import VectorStoreBase
+from pilot.configs.config import Config
 
 logger = logging.getLogger(__name__)
 
+CFG = Config()
+
 
 class PGVectorStore(VectorStoreBase):
     """`Postgres.PGVector` vector store.
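CFG = Config() works here because Config uses a Singleton metaclass (see the first hunk), so the vector store receives the same instance that was populated at startup. A minimal sketch of that pattern with stand-in class names; the project's Singleton implementation may differ:

class Singleton(type):
    """Metaclass that returns one shared instance per class (illustrative)."""

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]

class ExampleConfig(metaclass=Singleton):
    def __init__(self):
        self.proxy_server_url = None  # placeholder attribute

assert ExampleConfig() is ExampleConfig()  # every call returns the same object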