diff --git a/pilot/model/proxy/llms/chatgpt.py b/pilot/model/proxy/llms/chatgpt.py
index f54ec36b3..9e6d1a20a 100644
--- a/pilot/model/proxy/llms/chatgpt.py
+++ b/pilot/model/proxy/llms/chatgpt.py
@@ -9,6 +9,7 @@ from pilot.model.proxy.llms.proxy_model import ProxyModel
 from pilot.model.parameter import ProxyModelParameters
 from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 import httpx
+
 logger = logging.getLogger(__name__)
 
 
@@ -82,13 +83,12 @@ def _initialize_openai_v1(params: ProxyModelParameters):
         # Adapt previous proxy_server_url configuration
         base_url = params.proxy_server_url.split("/chat/completions")[0]
-
     proxies = params.http_proxy
     openai_params = {
         "api_key": api_key,
         "base_url": base_url,
     }
-    return openai_params, api_type, api_version, proxies
+    return openai_params, api_type, api_version, proxies
 
 
@@ -128,7 +128,9 @@ def _build_request(model: ProxyModel, params):
     proxyllm_backend = model_params.proxyllm_backend
 
     if metadata.version("openai") >= "1.0.0":
-        openai_params, api_type, api_version, proxies = _initialize_openai_v1(model_params)
+        openai_params, api_type, api_version, proxies = _initialize_openai_v1(
+            model_params
+        )
         proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
         payloads["model"] = proxyllm_backend
     else:
@@ -152,7 +154,9 @@ def chatgpt_generate_stream(
 ):
     if metadata.version("openai") >= "1.0.0":
         model_params = model.get_params()
-        openai_params, api_type, api_version, proxies = _initialize_openai_v1(model_params)
+        openai_params, api_type, api_version, proxies = _initialize_openai_v1(
+            model_params
+        )
         history, payloads = _build_request(model, params)
         if api_type == "azure":
             from openai import AzureOpenAI
@@ -160,15 +164,13 @@ def chatgpt_generate_stream(
         client = AzureOpenAI(
             api_key=openai_params["api_key"],
             api_version=api_version,
-            azure_endpoint=openai_params[
-                "base_url"
-            ],
-            http_client=httpx.Client(proxies=proxies)
+            azure_endpoint=openai_params["base_url"],
+            http_client=httpx.Client(proxies=proxies),
         )
     else:
         from openai import OpenAI
 
-        client = OpenAI(**openai_params,http_client=httpx.Client(proxies=proxies))
+        client = OpenAI(**openai_params, http_client=httpx.Client(proxies=proxies))
     res = client.chat.completions.create(messages=history, **payloads)
     text = ""
     for r in res:
@@ -197,7 +199,9 @@ async def async_chatgpt_generate_stream(
 ):
     if metadata.version("openai") >= "1.0.0":
         model_params = model.get_params()
-        openai_params, api_type, api_version,proxies = _initialize_openai_v1(model_params)
+        openai_params, api_type, api_version, proxies = _initialize_openai_v1(
+            model_params
+        )
         history, payloads = _build_request(model, params)
         if api_type == "azure":
             from openai import AsyncAzureOpenAI
@@ -205,15 +209,15 @@ async def async_chatgpt_generate_stream(
         client = AsyncAzureOpenAI(
             api_key=openai_params["api_key"],
             api_version=api_version,
-            azure_endpoint=openai_params[
-                "base_url"
-            ],
-            http_client=httpx.AsyncClient(proxies=proxies)
+            azure_endpoint=openai_params["base_url"],
+            http_client=httpx.AsyncClient(proxies=proxies),
         )
     else:
         from openai import AsyncOpenAI
 
-        client = AsyncOpenAI(**openai_params,http_client=httpx.AsyncClient(proxies=proxies))
+        client = AsyncOpenAI(
+            **openai_params, http_client=httpx.AsyncClient(proxies=proxies)
+        )
     res = await client.chat.completions.create(messages=history, **payloads)
     text = ""