mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-06 10:54:29 +00:00
fix(Model): Compatible with openai 1.x.x, compatible with AzureOpeai (#804)
Update chatgpt.py
This commit is contained in:
commit
1ad09c896f
@ -4,7 +4,7 @@
|
|||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List
|
||||||
import logging
|
import logging
|
||||||
|
import importlib.metadata as metadata
|
||||||
from pilot.model.proxy.llms.proxy_model import ProxyModel
|
from pilot.model.proxy.llms.proxy_model import ProxyModel
|
||||||
from pilot.model.parameter import ProxyModelParameters
|
from pilot.model.parameter import ProxyModelParameters
|
||||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||||
@ -57,14 +57,48 @@ def _initialize_openai(params: ProxyModelParameters):
|
|||||||
return openai_params
|
return openai_params
|
||||||
|
|
||||||
|
|
||||||
|
def _initialize_openai_v1(params: ProxyModelParameters):
|
||||||
|
try:
|
||||||
|
from openai import OpenAI
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import python package: openai "
|
||||||
|
"Please install openai by command `pip install openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
|
||||||
|
|
||||||
|
base_url = params.proxy_api_base or os.getenv(
|
||||||
|
"OPENAI_API_TYPE",
|
||||||
|
os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
|
||||||
|
)
|
||||||
|
api_key = params.proxy_api_key or os.getenv(
|
||||||
|
"OPENAI_API_KEY",
|
||||||
|
os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
|
||||||
|
)
|
||||||
|
api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")
|
||||||
|
|
||||||
|
if not base_url and params.proxy_server_url:
|
||||||
|
# Adapt previous proxy_server_url configuration
|
||||||
|
base_url = params.proxy_server_url.split("/chat/completions")[0]
|
||||||
|
|
||||||
|
if params.http_proxy:
|
||||||
|
openai.proxies = params.http_proxy
|
||||||
|
openai_params = {
|
||||||
|
"api_key": api_key,
|
||||||
|
"base_url": base_url,
|
||||||
|
"proxies": params.http_proxy,
|
||||||
|
}
|
||||||
|
|
||||||
|
return openai_params, api_type, api_version
|
||||||
|
|
||||||
|
|
||||||
def _build_request(model: ProxyModel, params):
|
def _build_request(model: ProxyModel, params):
|
||||||
history = []
|
history = []
|
||||||
|
|
||||||
model_params = model.get_params()
|
model_params = model.get_params()
|
||||||
logger.info(f"Model: {model}, model_params: {model_params}")
|
logger.info(f"Model: {model}, model_params: {model_params}")
|
||||||
|
|
||||||
openai_params = _initialize_openai(model_params)
|
|
||||||
|
|
||||||
messages: List[ModelMessage] = params["messages"]
|
messages: List[ModelMessage] = params["messages"]
|
||||||
# Add history conversation
|
# Add history conversation
|
||||||
for message in messages:
|
for message in messages:
|
||||||
@ -95,13 +129,19 @@ def _build_request(model: ProxyModel, params):
|
|||||||
}
|
}
|
||||||
proxyllm_backend = model_params.proxyllm_backend
|
proxyllm_backend = model_params.proxyllm_backend
|
||||||
|
|
||||||
if openai_params["api_type"] == "azure":
|
if metadata.version("openai") >= "1.0.0":
|
||||||
# engine = "deployment_name".
|
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
|
||||||
proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
|
|
||||||
payloads["engine"] = proxyllm_backend
|
|
||||||
else:
|
|
||||||
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
|
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
|
||||||
payloads["model"] = proxyllm_backend
|
payloads["model"] = proxyllm_backend
|
||||||
|
else:
|
||||||
|
openai_params = _initialize_openai(model_params)
|
||||||
|
if openai_params["api_type"] == "azure":
|
||||||
|
# engine = "deployment_name".
|
||||||
|
proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
|
||||||
|
payloads["engine"] = proxyllm_backend
|
||||||
|
else:
|
||||||
|
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
|
||||||
|
payloads["model"] = proxyllm_backend
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
|
f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
|
||||||
@ -112,32 +152,87 @@ def _build_request(model: ProxyModel, params):
|
|||||||
def chatgpt_generate_stream(
|
def chatgpt_generate_stream(
|
||||||
model: ProxyModel, tokenizer, params, device, context_len=2048
|
model: ProxyModel, tokenizer, params, device, context_len=2048
|
||||||
):
|
):
|
||||||
import openai
|
if metadata.version("openai") >= "1.0.0":
|
||||||
|
model_params = model.get_params()
|
||||||
|
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
|
||||||
|
history, payloads = _build_request(model, params)
|
||||||
|
if api_type == "azure":
|
||||||
|
from openai import AzureOpenAI
|
||||||
|
|
||||||
history, payloads = _build_request(model, params)
|
client = AzureOpenAI(
|
||||||
|
api_key=openai_params["api_key"],
|
||||||
|
api_version=api_version,
|
||||||
|
azure_endpoint=openai_params[
|
||||||
|
"base_url"
|
||||||
|
], # Your Azure OpenAI resource's endpoint value.
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
res = openai.ChatCompletion.create(messages=history, **payloads)
|
client = OpenAI(**openai_params)
|
||||||
|
res = client.chat.completions.create(messages=history, **payloads)
|
||||||
|
text = ""
|
||||||
|
for r in res:
|
||||||
|
if r.choices[0].delta.content is not None:
|
||||||
|
content = r.choices[0].delta.content
|
||||||
|
text += content
|
||||||
|
yield text
|
||||||
|
|
||||||
text = ""
|
else:
|
||||||
for r in res:
|
import openai
|
||||||
if r["choices"][0]["delta"].get("content") is not None:
|
|
||||||
content = r["choices"][0]["delta"]["content"]
|
history, payloads = _build_request(model, params)
|
||||||
text += content
|
|
||||||
yield text
|
res = openai.ChatCompletion.create(messages=history, **payloads)
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
print("res", res)
|
||||||
|
for r in res:
|
||||||
|
if r["choices"][0]["delta"].get("content") is not None:
|
||||||
|
content = r["choices"][0]["delta"]["content"]
|
||||||
|
text += content
|
||||||
|
yield text
|
||||||
|
|
||||||
|
|
||||||
async def async_chatgpt_generate_stream(
|
async def async_chatgpt_generate_stream(
|
||||||
model: ProxyModel, tokenizer, params, device, context_len=2048
|
model: ProxyModel, tokenizer, params, device, context_len=2048
|
||||||
):
|
):
|
||||||
import openai
|
if metadata.version("openai") >= "1.0.0":
|
||||||
|
model_params = model.get_params()
|
||||||
|
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
|
||||||
|
history, payloads = _build_request(model, params)
|
||||||
|
if api_type == "azure":
|
||||||
|
from openai import AsyncAzureOpenAI
|
||||||
|
|
||||||
history, payloads = _build_request(model, params)
|
client = AsyncAzureOpenAI(
|
||||||
|
api_key=openai_params["api_key"],
|
||||||
|
api_version=api_version,
|
||||||
|
azure_endpoint=openai_params[
|
||||||
|
"base_url"
|
||||||
|
], # Your Azure OpenAI resource's endpoint value.
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
res = await openai.ChatCompletion.acreate(messages=history, **payloads)
|
client = AsyncOpenAI(**openai_params)
|
||||||
|
|
||||||
text = ""
|
res = await client.chat.completions.create(messages=history, **payloads)
|
||||||
async for r in res:
|
text = ""
|
||||||
if r["choices"][0]["delta"].get("content") is not None:
|
for r in res:
|
||||||
content = r["choices"][0]["delta"]["content"]
|
if r.choices[0].delta.content is not None:
|
||||||
text += content
|
content = r.choices[0].delta.content
|
||||||
yield text
|
text += content
|
||||||
|
yield text
|
||||||
|
else:
|
||||||
|
import openai
|
||||||
|
|
||||||
|
history, payloads = _build_request(model, params)
|
||||||
|
|
||||||
|
res = await openai.ChatCompletion.acreate(messages=history, **payloads)
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
async for r in res:
|
||||||
|
if r["choices"][0]["delta"].get("content") is not None:
|
||||||
|
content = r["choices"][0]["delta"]["content"]
|
||||||
|
text += content
|
||||||
|
yield text
|
||||||
|
Loading…
Reference in New Issue
Block a user