mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 17:45:31 +00:00
Update chatgpt.py
This commit is contained in:
@@ -56,14 +56,15 @@ def _initialize_openai(params: ProxyModelParameters):
|
||||
|
||||
return openai_params
|
||||
|
||||
|
||||
def _initialize_openai_v1(params: ProxyModelParameters):
|
||||
try:
|
||||
from openai import OpenAI
|
||||
except ImportError as exc:
|
||||
raise ValueError(
|
||||
"Could not import python package: openai "
|
||||
"Please install openai by command `pip install openai"
|
||||
)
|
||||
raise ValueError(
|
||||
"Could not import python package: openai "
|
||||
"Please install openai by command `pip install openai"
|
||||
)
|
||||
|
||||
api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
|
||||
|
||||
@@ -89,11 +90,7 @@ def _initialize_openai_v1(params: ProxyModelParameters):
|
||||
"proxies": params.http_proxy,
|
||||
}
|
||||
|
||||
return openai_params,api_type,api_version
|
||||
|
||||
|
||||
|
||||
|
||||
return openai_params, api_type, api_version
|
||||
|
||||
|
||||
def _build_request(model: ProxyModel, params):
|
||||
@@ -102,7 +99,6 @@ def _build_request(model: ProxyModel, params):
|
||||
model_params = model.get_params()
|
||||
logger.info(f"Model: {model}, model_params: {model_params}")
|
||||
|
||||
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
# Add history conversation
|
||||
for message in messages:
|
||||
@@ -134,7 +130,7 @@ def _build_request(model: ProxyModel, params):
|
||||
proxyllm_backend = model_params.proxyllm_backend
|
||||
|
||||
if metadata.version("openai") >= "1.0.0":
|
||||
openai_params,api_type,api_version = _initialize_openai_v1(model_params)
|
||||
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
|
||||
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
|
||||
payloads["model"] = proxyllm_backend
|
||||
else:
|
||||
@@ -158,20 +154,24 @@ def chatgpt_generate_stream(
|
||||
):
|
||||
if metadata.version("openai") >= "1.0.0":
|
||||
model_params = model.get_params()
|
||||
openai_params,api_type,api_version = _initialize_openai_v1(model_params)
|
||||
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
|
||||
history, payloads = _build_request(model, params)
|
||||
if api_type == "azure":
|
||||
from openai import AzureOpenAI
|
||||
|
||||
client = AzureOpenAI(
|
||||
api_key = openai_params["api_key"],
|
||||
api_version = api_version,
|
||||
azure_endpoint = openai_params["base_url"] # Your Azure OpenAI resource's endpoint value.
|
||||
)
|
||||
api_key=openai_params["api_key"],
|
||||
api_version=api_version,
|
||||
azure_endpoint=openai_params[
|
||||
"base_url"
|
||||
], # Your Azure OpenAI resource's endpoint value.
|
||||
)
|
||||
else:
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(**openai_params)
|
||||
print("openai_params",openai_params)
|
||||
print("payloads",payloads)
|
||||
print("openai_params", openai_params)
|
||||
print("payloads", payloads)
|
||||
res = client.chat.completions.create(messages=history, **payloads)
|
||||
print(res)
|
||||
text = ""
|
||||
@@ -182,14 +182,14 @@ def chatgpt_generate_stream(
|
||||
yield text
|
||||
|
||||
else:
|
||||
|
||||
import openai
|
||||
|
||||
history, payloads = _build_request(model, params)
|
||||
|
||||
res = openai.ChatCompletion.create(messages=history, **payloads)
|
||||
|
||||
text = ""
|
||||
print("res",res)
|
||||
print("res", res)
|
||||
for r in res:
|
||||
if r["choices"][0]["delta"].get("content") is not None:
|
||||
content = r["choices"][0]["delta"]["content"]
|
||||
@@ -202,22 +202,29 @@ async def async_chatgpt_generate_stream(
|
||||
):
|
||||
if metadata.version("openai") >= "1.0.0":
|
||||
model_params = model.get_params()
|
||||
openai_params,api_type,api_version = _initialize_openai_v1(model_params)
|
||||
openai_params, api_type, api_version = _initialize_openai_v1(model_params)
|
||||
history, payloads = _build_request(model, params)
|
||||
if api_type == "azure":
|
||||
from openai import AsyncAzureOpenAI
|
||||
|
||||
client = AsyncAzureOpenAI(
|
||||
api_key = openai_params["api_key"],
|
||||
end_point = openai_params["base_url"],
|
||||
api_version = api_version,
|
||||
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") # Your Azure OpenAI resource's endpoint value.
|
||||
)
|
||||
api_key=openai_params["api_key"],
|
||||
end_point=openai_params["base_url"],
|
||||
api_version=api_version,
|
||||
azure_endpoint=os.getenv(
|
||||
"AZURE_OPENAI_ENDPOINT"
|
||||
), # Your Azure OpenAI resource's endpoint value.
|
||||
)
|
||||
else:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
client = AsyncOpenAI(**openai_params)
|
||||
res = await client.chat.completions.create(messages=history, **payloads).model_dump()
|
||||
res = await client.chat.completions.create(
|
||||
messages=history, **payloads
|
||||
).model_dump()
|
||||
else:
|
||||
import openai
|
||||
|
||||
history, payloads = _build_request(model, params)
|
||||
|
||||
res = await openai.ChatCompletion.acreate(messages=history, **payloads)
|
||||
|
Reference in New Issue
Block a user