mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 08:11:45 +00:00
Update chatgpt.py
This commit is contained in:
parent
5eccf73f24
commit
ddabd019e3
@ -9,6 +9,7 @@ from pilot.model.proxy.llms.proxy_model import ProxyModel
|
|||||||
from pilot.model.parameter import ProxyModelParameters
|
from pilot.model.parameter import ProxyModelParameters
|
||||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -82,7 +83,6 @@ def _initialize_openai_v1(params: ProxyModelParameters):
|
|||||||
# Adapt previous proxy_server_url configuration
|
# Adapt previous proxy_server_url configuration
|
||||||
base_url = params.proxy_server_url.split("/chat/completions")[0]
|
base_url = params.proxy_server_url.split("/chat/completions")[0]
|
||||||
|
|
||||||
|
|
||||||
proxies = params.http_proxy
|
proxies = params.http_proxy
|
||||||
openai_params = {
|
openai_params = {
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
@ -128,7 +128,9 @@ def _build_request(model: ProxyModel, params):
|
|||||||
proxyllm_backend = model_params.proxyllm_backend
|
proxyllm_backend = model_params.proxyllm_backend
|
||||||
|
|
||||||
if metadata.version("openai") >= "1.0.0":
|
if metadata.version("openai") >= "1.0.0":
|
||||||
openai_params, api_type, api_version, proxies = _initialize_openai_v1(model_params)
|
openai_params, api_type, api_version, proxies = _initialize_openai_v1(
|
||||||
|
model_params
|
||||||
|
)
|
||||||
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
|
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
|
||||||
payloads["model"] = proxyllm_backend
|
payloads["model"] = proxyllm_backend
|
||||||
else:
|
else:
|
||||||
@ -152,7 +154,9 @@ def chatgpt_generate_stream(
|
|||||||
):
|
):
|
||||||
if metadata.version("openai") >= "1.0.0":
|
if metadata.version("openai") >= "1.0.0":
|
||||||
model_params = model.get_params()
|
model_params = model.get_params()
|
||||||
openai_params, api_type, api_version, proxies = _initialize_openai_v1(model_params)
|
openai_params, api_type, api_version, proxies = _initialize_openai_v1(
|
||||||
|
model_params
|
||||||
|
)
|
||||||
history, payloads = _build_request(model, params)
|
history, payloads = _build_request(model, params)
|
||||||
if api_type == "azure":
|
if api_type == "azure":
|
||||||
from openai import AzureOpenAI
|
from openai import AzureOpenAI
|
||||||
@ -160,10 +164,8 @@ def chatgpt_generate_stream(
|
|||||||
client = AzureOpenAI(
|
client = AzureOpenAI(
|
||||||
api_key=openai_params["api_key"],
|
api_key=openai_params["api_key"],
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
azure_endpoint=openai_params[
|
azure_endpoint=openai_params["base_url"],
|
||||||
"base_url"
|
http_client=httpx.Client(proxies=proxies),
|
||||||
],
|
|
||||||
http_client=httpx.Client(proxies=proxies)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
@ -197,7 +199,9 @@ async def async_chatgpt_generate_stream(
|
|||||||
):
|
):
|
||||||
if metadata.version("openai") >= "1.0.0":
|
if metadata.version("openai") >= "1.0.0":
|
||||||
model_params = model.get_params()
|
model_params = model.get_params()
|
||||||
openai_params, api_type, api_version,proxies = _initialize_openai_v1(model_params)
|
openai_params, api_type, api_version, proxies = _initialize_openai_v1(
|
||||||
|
model_params
|
||||||
|
)
|
||||||
history, payloads = _build_request(model, params)
|
history, payloads = _build_request(model, params)
|
||||||
if api_type == "azure":
|
if api_type == "azure":
|
||||||
from openai import AsyncAzureOpenAI
|
from openai import AsyncAzureOpenAI
|
||||||
@ -205,15 +209,15 @@ async def async_chatgpt_generate_stream(
|
|||||||
client = AsyncAzureOpenAI(
|
client = AsyncAzureOpenAI(
|
||||||
api_key=openai_params["api_key"],
|
api_key=openai_params["api_key"],
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
azure_endpoint=openai_params[
|
azure_endpoint=openai_params["base_url"],
|
||||||
"base_url"
|
http_client=httpx.AsyncClient(proxies=proxies),
|
||||||
],
|
|
||||||
http_client=httpx.AsyncClient(proxies=proxies)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
client = AsyncOpenAI(**openai_params,http_client=httpx.AsyncClient(proxies=proxies))
|
client = AsyncOpenAI(
|
||||||
|
**openai_params, http_client=httpx.AsyncClient(proxies=proxies)
|
||||||
|
)
|
||||||
|
|
||||||
res = await client.chat.completions.create(messages=history, **payloads)
|
res = await client.chat.completions.create(messages=history, **payloads)
|
||||||
text = ""
|
text = ""
|
||||||
|
Loading…
Reference in New Issue
Block a user