diff --git a/pilot/model/proxy/llms/chatgpt.py b/pilot/model/proxy/llms/chatgpt.py
index 21d598d5a..9e6d1a20a 100644
--- a/pilot/model/proxy/llms/chatgpt.py
+++ b/pilot/model/proxy/llms/chatgpt.py
@@ -8,6 +8,7 @@ import importlib.metadata as metadata
 from pilot.model.proxy.llms.proxy_model import ProxyModel
 from pilot.model.parameter import ProxyModelParameters
 from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+import httpx
 
 logger = logging.getLogger(__name__)
 
@@ -82,15 +83,12 @@ def _initialize_openai_v1(params: ProxyModelParameters):
         # Adapt previous proxy_server_url configuration
         base_url = params.proxy_server_url.split("/chat/completions")[0]
 
-    if params.http_proxy:
-        openai.proxies = params.http_proxy
+    proxies = params.http_proxy
     openai_params = {
         "api_key": api_key,
         "base_url": base_url,
-        "proxies": params.http_proxy,
     }
-
-    return openai_params, api_type, api_version
+    return openai_params, api_type, api_version, proxies
 
 
 def _build_request(model: ProxyModel, params):
@@ -130,7 +128,9 @@ def _build_request(model: ProxyModel, params):
     proxyllm_backend = model_params.proxyllm_backend
 
     if metadata.version("openai") >= "1.0.0":
-        openai_params, api_type, api_version = _initialize_openai_v1(model_params)
+        openai_params, api_type, api_version, proxies = _initialize_openai_v1(
+            model_params
+        )
         proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
         payloads["model"] = proxyllm_backend
     else:
@@ -154,7 +154,9 @@ def chatgpt_generate_stream(
 ):
     if metadata.version("openai") >= "1.0.0":
         model_params = model.get_params()
-        openai_params, api_type, api_version = _initialize_openai_v1(model_params)
+        openai_params, api_type, api_version, proxies = _initialize_openai_v1(
+            model_params
+        )
         history, payloads = _build_request(model, params)
         if api_type == "azure":
             from openai import AzureOpenAI
@@ -162,14 +164,13 @@
             client = AzureOpenAI(
                 api_key=openai_params["api_key"],
                 api_version=api_version,
-                azure_endpoint=openai_params[
-                    "base_url"
-                ],  # Your Azure OpenAI resource's endpoint value.
+                azure_endpoint=openai_params["base_url"],
+                http_client=httpx.Client(proxies=proxies),
             )
         else:
             from openai import OpenAI
 
-            client = OpenAI(**openai_params)
+            client = OpenAI(**openai_params, http_client=httpx.Client(proxies=proxies))
         res = client.chat.completions.create(messages=history, **payloads)
         text = ""
         for r in res:
@@ -186,7 +187,6 @@ def chatgpt_generate_stream(
 
         res = openai.ChatCompletion.create(messages=history, **payloads)
         text = ""
-        print("res", res)
         for r in res:
             if r["choices"][0]["delta"].get("content") is not None:
                 content = r["choices"][0]["delta"]["content"]
@@ -199,7 +199,9 @@ async def async_chatgpt_generate_stream(
 ):
     if metadata.version("openai") >= "1.0.0":
         model_params = model.get_params()
-        openai_params, api_type, api_version = _initialize_openai_v1(model_params)
+        openai_params, api_type, api_version, proxies = _initialize_openai_v1(
+            model_params
+        )
         history, payloads = _build_request(model, params)
         if api_type == "azure":
             from openai import AsyncAzureOpenAI
@@ -207,14 +209,15 @@
             client = AsyncAzureOpenAI(
                 api_key=openai_params["api_key"],
                 api_version=api_version,
-                azure_endpoint=openai_params[
-                    "base_url"
-                ],  # Your Azure OpenAI resource's endpoint value.
+                azure_endpoint=openai_params["base_url"],
+                http_client=httpx.AsyncClient(proxies=proxies),
             )
         else:
             from openai import AsyncOpenAI
 
-            client = AsyncOpenAI(**openai_params)
+            client = AsyncOpenAI(
+                **openai_params, http_client=httpx.AsyncClient(proxies=proxies)
+            )
         res = await client.chat.completions.create(messages=history, **payloads)
         text = ""
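For reference, a minimal sketch of the pattern this diff adopts: instead of mutating the global openai.proxies, the proxy URL returned by _initialize_openai_v1 is threaded into an explicit httpx client, which the openai SDK (>= 1.0.0) accepts via its http_client parameter. The API key and proxy URL below are placeholders, and httpx.Client(proxies=...) assumes an httpx version that still accepts the proxies argument (later httpx releases renamed it to proxy).

import httpx
from openai import OpenAI

http_proxy = "http://127.0.0.1:7890"  # placeholder proxy URL

# When http_proxy is None, httpx connects directly, so the same
# construction works whether or not a proxy is configured.
client = OpenAI(
    api_key="sk-...",  # placeholder
    base_url="https://api.openai.com/v1",
    http_client=httpx.Client(proxies=http_proxy),
)

The async paths use the same construction with httpx.AsyncClient and AsyncOpenAI / AsyncAzureOpenAI.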