diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 5cd173e70..d77d29f9d 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -51,8 +51,47 @@ class Config(metaclass=Singleton): if self.bard_proxy_api_key: os.environ["bard_proxyllm_proxy_api_key"] = self.bard_proxy_api_key - self.tongyi_api_key = os.getenv("TONGYI_PROXY_API_KEY") + # tongyi + self.tongyi_proxy_api_key = os.getenv("TONGYI_PROXY_API_KEY") + if self.tongyi_proxy_api_key: + os.environ["tongyi_proxyllm_proxy_api_key"] = self.tongyi_proxy_api_key + + # zhipu + self.zhipu_proxy_api_key = os.getenv("ZHIPU_PROXY_API_KEY") + if self.zhipu_proxy_api_key: + os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key + os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv("ZHIPU_MODEL_VERSION", "") + # wenxin + self.wenxin_proxy_api_key = os.getenv("WEN_XIN_API_KEY") + self.wenxin_proxy_api_secret = os.getenv("WEN_XIN_SECRET_KEY") + self.wenxin_model_version = os.getenv("WEN_XIN_MODEL_VERSION") + if self.wenxin_proxy_api_key and self.wenxin_proxy_api_secret: + os.environ["wenxin_proxyllm_proxy_api_key"] = self.wenxin_proxy_api_key + os.environ["wenxin_proxyllm_proxy_api_secret"] = self.wenxin_proxy_api_secret + os.environ["wenxin_proxyllm_proxyllm_backend"] = self.wenxin_model_version or "" + + # xunfei spark + self.spark_api_version = os.getenv("XUNFEI_SPARK_API_VERSION") + self.spark_proxy_api_key = os.getenv("XUNFEI_SPARK_API_KEY") + self.spark_proxy_api_secret = os.getenv("XUNFEI_SPARK_API_SECRET") + self.spark_proxy_api_appid = os.getenv("XUNFEI_SPARK_APPID") + if self.spark_proxy_api_key and self.spark_proxy_api_secret: + os.environ["spark_proxyllm_proxy_api_key"] = self.spark_proxy_api_key + os.environ["spark_proxyllm_proxy_api_secret"] = self.spark_proxy_api_secret + os.environ["spark_proxyllm_proxyllm_backend"] = self.spark_api_version or "" + os.environ["spark_proxyllm_proxy_app_id"] = self.spark_proxy_api_appid or "" + + # baichuan proxy + self.bc_proxy_api_key = 
os.getenv("BAICHUAN_PROXY_API_KEY") + self.bc_proxy_api_secret = os.getenv("BAICHUAN_PROXY_API_SECRET") + self.bc_model_version = os.getenv("BAICHUN_MODEL_NAME") + if self.bc_proxy_api_key and self.bc_proxy_api_secret: + os.environ["bc_proxyllm_proxy_api_key"] = self.bc_proxy_api_key + os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret + os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version or "" + + self.proxy_server_url = os.getenv("PROXY_SERVER_URL") self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY") diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py index 2a3df1835..271c6bd38 100644 --- a/pilot/model/parameter.py +++ b/pilot/model/parameter.py @@ -285,6 +285,19 @@ class ProxyModelParameters(BaseModelParameters): proxy_api_key: str = field( metadata={"tags": "privacy", "help": "The api key of current proxy LLM"}, ) + + proxy_app_id: Optional[str] = field( + default=None, + metadata={ + "help": "The app id of current proxy LLM" + }, + ) + + proxy_api_secret: Optional[str] = field( + default=None, + metadata={"tags": "privacy", "help": "The api secret of current proxy LLM"}, + ) + http_proxy: Optional[str] = field( default=os.environ.get("http_proxy") or os.environ.get("https_proxy"), metadata={"help": "The http or https proxy to use openai"}, diff --git a/pilot/model/proxy/llms/baichuan.py b/pilot/model/proxy/llms/baichuan.py index 436002cc6..6dd5cacad 100644 --- a/pilot/model/proxy/llms/baichuan.py +++ b/pilot/model/proxy/llms/baichuan.py @@ -27,9 +27,9 @@ def baichuan_generate_stream( model_params = model.get_params() url = "https://api.baichuan-ai.com/v1/stream/chat" - model_name = os.getenv("BAICHUN_MODEL_NAME") or BAICHUAN_DEFAULT_MODEL - proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY") - proxy_api_secret = os.getenv("BAICHUAN_PROXY_API_SECRET") + model_name = model_params.proxyllm_backend or BAICHUAN_DEFAULT_MODEL + proxy_api_key = model_params.proxy_api_key + proxy_api_secret = model_params.proxy_api_secret history 
= [] diff --git a/pilot/model/proxy/llms/spark.py b/pilot/model/proxy/llms/spark.py index 47b665b3c..2a6a1579a 100644 --- a/pilot/model/proxy/llms/spark.py +++ b/pilot/model/proxy/llms/spark.py @@ -19,10 +19,10 @@ def spark_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): model_params = model.get_params() - proxy_api_version = os.getenv("XUNFEI_SPARK_API_VERSION") or SPARK_DEFAULT_API_VERSION - proxy_api_key = os.getenv("XUNFEI_SPARK_API_KEY") - proxy_api_secret = os.getenv("XUNFEI_SPARK_API_SECRET") - proxy_app_id = os.getenv("XUNFEI_SPARK_APPID") + proxy_api_version = model_params.proxyllm_backend or SPARK_DEFAULT_API_VERSION + proxy_api_key = model_params.proxy_api_key + proxy_api_secret = model_params.proxy_api_secret + proxy_app_id = model_params.proxy_app_id if proxy_api_version == SPARK_DEFAULT_API_VERSION: url = "ws://spark-api.xf-yun.com/v2.1/chat" diff --git a/pilot/model/proxy/llms/tongyi.py b/pilot/model/proxy/llms/tongyi.py index 631424ae7..34aacff15 100644 --- a/pilot/model/proxy/llms/tongyi.py +++ b/pilot/model/proxy/llms/tongyi.py @@ -15,8 +15,8 @@ def tongyi_generate_stream( model_params = model.get_params() print(f"Model: {model}, model_params: {model_params}") - # proxy_api_key = model_params.proxy_api_key # // TODO Set this according env - dashscope.api_key = os.getenv("TONGYI_PROXY_API_KEY") + proxy_api_key = model_params.proxy_api_key + dashscope.api_key = proxy_api_key proxyllm_backend = model_params.proxyllm_backend diff --git a/pilot/model/proxy/llms/wenxin.py b/pilot/model/proxy/llms/wenxin.py index 0053005f5..74b235a46 100644 --- a/pilot/model/proxy/llms/wenxin.py +++ b/pilot/model/proxy/llms/wenxin.py @@ -29,13 +29,13 @@ def wenxin_generate_stream( } model_params = model.get_params() - model_name = os.getenv("WEN_XIN_MODEL_VERSION") + model_name = model_params.proxyllm_backend model_version = MODEL_VERSION.get(model_name) if not model_version: yield f"Unsupport model version {model_name}" - proxy_api_key 
= os.getenv("WEN_XIN_API_KEY") - proxy_api_secret = os.getenv("WEN_XIN_SECRET_KEY") + proxy_api_key = model_params.proxy_api_key + proxy_api_secret = model_params.proxy_api_secret access_token = _build_access_token(proxy_api_key, proxy_api_secret) headers = { diff --git a/pilot/model/proxy/llms/zhipu.py b/pilot/model/proxy/llms/zhipu.py index 5c97393ca..46fce50a0 100644 --- a/pilot/model/proxy/llms/zhipu.py +++ b/pilot/model/proxy/llms/zhipu.py @@ -14,6 +14,7 @@ def zhipu_generate_stream( model_params = model.get_params() print(f"Model: {model}, model_params: {model_params}") + # TODO proxy model use unified config? proxy_api_key = model_params.proxy_api_key proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend diff --git a/pilot/vector_store/pgvector_store.py b/pilot/vector_store/pgvector_store.py index 8155dd61c..98ce4a027 100644 --- a/pilot/vector_store/pgvector_store.py +++ b/pilot/vector_store/pgvector_store.py @@ -1,9 +1,12 @@ from typing import Any import logging from pilot.vector_store.base import VectorStoreBase +from pilot.configs.config import Config logger = logging.getLogger(__name__) +CFG = Config() + class PGVectorStore(VectorStoreBase): """`Postgres.PGVector` vector store.