mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-30 15:21:02 +00:00
Merge remote-tracking branch 'origin/main' into feat_rag_graph
This commit is contained in:
commit
870aeb02b4
@ -8,6 +8,7 @@ import importlib.metadata as metadata
|
||||
from pilot.model.proxy.llms.proxy_model import ProxyModel
|
||||
from pilot.model.parameter import ProxyModelParameters
|
||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -82,15 +83,12 @@ def _initialize_openai_v1(params: ProxyModelParameters):
|
||||
# Adapt previous proxy_server_url configuration
|
||||
base_url = params.proxy_server_url.split("/chat/completions")[0]
|
||||
|
||||
if params.http_proxy:
|
||||
openai.proxies = params.http_proxy
|
||||
proxies = params.http_proxy
|
||||
openai_params = {
|
||||
"api_key": api_key,
|
||||
"base_url": base_url,
|
||||
"proxies": params.http_proxy,
|
||||
}
|
||||
|
||||
return openai_params, api_type, api_version
|
||||
return openai_params, api_type, api_version, proxies
|
||||
|
||||
|
||||
def _build_request(model: ProxyModel, params):
|
||||
@ -130,7 +128,9 @@ def _build_request(model: ProxyModel, params):
|
||||
proxyllm_backend = model_params.proxyllm_backend
|
||||
|
||||
if metadata.version("openai") >= "1.0.0":
|
||||
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
|
||||
openai_params, api_type, api_version, proxies = _initialize_openai_v1(
|
||||
model_params
|
||||
)
|
||||
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
|
||||
payloads["model"] = proxyllm_backend
|
||||
else:
|
||||
@ -154,7 +154,9 @@ def chatgpt_generate_stream(
|
||||
):
|
||||
if metadata.version("openai") >= "1.0.0":
|
||||
model_params = model.get_params()
|
||||
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
|
||||
openai_params, api_type, api_version, proxies = _initialize_openai_v1(
|
||||
model_params
|
||||
)
|
||||
history, payloads = _build_request(model, params)
|
||||
if api_type == "azure":
|
||||
from openai import AzureOpenAI
|
||||
@ -162,14 +164,13 @@ def chatgpt_generate_stream(
|
||||
client = AzureOpenAI(
|
||||
api_key=openai_params["api_key"],
|
||||
api_version=api_version,
|
||||
azure_endpoint=openai_params[
|
||||
"base_url"
|
||||
], # Your Azure OpenAI resource's endpoint value.
|
||||
azure_endpoint=openai_params["base_url"],
|
||||
http_client=httpx.Client(proxies=proxies),
|
||||
)
|
||||
else:
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(**openai_params)
|
||||
client = OpenAI(**openai_params, http_client=httpx.Client(proxies=proxies))
|
||||
res = client.chat.completions.create(messages=history, **payloads)
|
||||
text = ""
|
||||
for r in res:
|
||||
@ -186,7 +187,6 @@ def chatgpt_generate_stream(
|
||||
res = openai.ChatCompletion.create(messages=history, **payloads)
|
||||
|
||||
text = ""
|
||||
print("res", res)
|
||||
for r in res:
|
||||
if r["choices"][0]["delta"].get("content") is not None:
|
||||
content = r["choices"][0]["delta"]["content"]
|
||||
@ -199,7 +199,9 @@ async def async_chatgpt_generate_stream(
|
||||
):
|
||||
if metadata.version("openai") >= "1.0.0":
|
||||
model_params = model.get_params()
|
||||
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
|
||||
openai_params, api_type, api_version, proxies = _initialize_openai_v1(
|
||||
model_params
|
||||
)
|
||||
history, payloads = _build_request(model, params)
|
||||
if api_type == "azure":
|
||||
from openai import AsyncAzureOpenAI
|
||||
@ -207,14 +209,15 @@ async def async_chatgpt_generate_stream(
|
||||
client = AsyncAzureOpenAI(
|
||||
api_key=openai_params["api_key"],
|
||||
api_version=api_version,
|
||||
azure_endpoint=openai_params[
|
||||
"base_url"
|
||||
], # Your Azure OpenAI resource's endpoint value.
|
||||
azure_endpoint=openai_params["base_url"],
|
||||
http_client=httpx.AsyncClient(proxies=proxies),
|
||||
)
|
||||
else:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
client = AsyncOpenAI(**openai_params)
|
||||
client = AsyncOpenAI(
|
||||
**openai_params, http_client=httpx.AsyncClient(proxies=proxies)
|
||||
)
|
||||
|
||||
res = await client.chat.completions.create(messages=history, **payloads)
|
||||
text = ""
|
||||
|
Loading…
Reference in New Issue
Block a user