Merge remote-tracking branch 'origin/main' into feat_rag_graph

This commit is contained in:
aries_ckt 2023-11-17 11:40:42 +08:00
commit 870aeb02b4

View File

@ -8,6 +8,7 @@ import importlib.metadata as metadata
from pilot.model.proxy.llms.proxy_model import ProxyModel from pilot.model.proxy.llms.proxy_model import ProxyModel
from pilot.model.parameter import ProxyModelParameters from pilot.model.parameter import ProxyModelParameters
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
import httpx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -82,15 +83,12 @@ def _initialize_openai_v1(params: ProxyModelParameters):
# Adapt previous proxy_server_url configuration # Adapt previous proxy_server_url configuration
base_url = params.proxy_server_url.split("/chat/completions")[0] base_url = params.proxy_server_url.split("/chat/completions")[0]
if params.http_proxy: proxies = params.http_proxy
openai.proxies = params.http_proxy
openai_params = { openai_params = {
"api_key": api_key, "api_key": api_key,
"base_url": base_url, "base_url": base_url,
"proxies": params.http_proxy,
} }
return openai_params, api_type, api_version, proxies
return openai_params, api_type, api_version
def _build_request(model: ProxyModel, params): def _build_request(model: ProxyModel, params):
@ -130,7 +128,9 @@ def _build_request(model: ProxyModel, params):
proxyllm_backend = model_params.proxyllm_backend proxyllm_backend = model_params.proxyllm_backend
if metadata.version("openai") >= "1.0.0": if metadata.version("openai") >= "1.0.0":
openai_params, api_type, api_version = _initialize_openai_v1(model_params) openai_params, api_type, api_version, proxies = _initialize_openai_v1(
model_params
)
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo" proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
payloads["model"] = proxyllm_backend payloads["model"] = proxyllm_backend
else: else:
@ -154,7 +154,9 @@ def chatgpt_generate_stream(
): ):
if metadata.version("openai") >= "1.0.0": if metadata.version("openai") >= "1.0.0":
model_params = model.get_params() model_params = model.get_params()
openai_params, api_type, api_version = _initialize_openai_v1(model_params) openai_params, api_type, api_version, proxies = _initialize_openai_v1(
model_params
)
history, payloads = _build_request(model, params) history, payloads = _build_request(model, params)
if api_type == "azure": if api_type == "azure":
from openai import AzureOpenAI from openai import AzureOpenAI
@ -162,14 +164,13 @@ def chatgpt_generate_stream(
client = AzureOpenAI( client = AzureOpenAI(
api_key=openai_params["api_key"], api_key=openai_params["api_key"],
api_version=api_version, api_version=api_version,
azure_endpoint=openai_params[ azure_endpoint=openai_params["base_url"],
"base_url" http_client=httpx.Client(proxies=proxies),
], # Your Azure OpenAI resource's endpoint value.
) )
else: else:
from openai import OpenAI from openai import OpenAI
client = OpenAI(**openai_params) client = OpenAI(**openai_params, http_client=httpx.Client(proxies=proxies))
res = client.chat.completions.create(messages=history, **payloads) res = client.chat.completions.create(messages=history, **payloads)
text = "" text = ""
for r in res: for r in res:
@ -186,7 +187,6 @@ def chatgpt_generate_stream(
res = openai.ChatCompletion.create(messages=history, **payloads) res = openai.ChatCompletion.create(messages=history, **payloads)
text = "" text = ""
print("res", res)
for r in res: for r in res:
if r["choices"][0]["delta"].get("content") is not None: if r["choices"][0]["delta"].get("content") is not None:
content = r["choices"][0]["delta"]["content"] content = r["choices"][0]["delta"]["content"]
@ -199,7 +199,9 @@ async def async_chatgpt_generate_stream(
): ):
if metadata.version("openai") >= "1.0.0": if metadata.version("openai") >= "1.0.0":
model_params = model.get_params() model_params = model.get_params()
openai_params, api_type, api_version = _initialize_openai_v1(model_params) openai_params, api_type, api_version, proxies = _initialize_openai_v1(
model_params
)
history, payloads = _build_request(model, params) history, payloads = _build_request(model, params)
if api_type == "azure": if api_type == "azure":
from openai import AsyncAzureOpenAI from openai import AsyncAzureOpenAI
@ -207,14 +209,15 @@ async def async_chatgpt_generate_stream(
client = AsyncAzureOpenAI( client = AsyncAzureOpenAI(
api_key=openai_params["api_key"], api_key=openai_params["api_key"],
api_version=api_version, api_version=api_version,
azure_endpoint=openai_params[ azure_endpoint=openai_params["base_url"],
"base_url" http_client=httpx.AsyncClient(proxies=proxies),
], # Your Azure OpenAI resource's endpoint value.
) )
else: else:
from openai import AsyncOpenAI from openai import AsyncOpenAI
client = AsyncOpenAI(**openai_params) client = AsyncOpenAI(
**openai_params, http_client=httpx.AsyncClient(proxies=proxies)
)
res = await client.chat.completions.create(messages=history, **payloads) res = await client.chat.completions.create(messages=history, **payloads)
text = "" text = ""