mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 10:05:13 +00:00
Update chatgpt.py
This commit is contained in:
@@ -56,6 +56,7 @@ def _initialize_openai(params: ProxyModelParameters):
|
|||||||
|
|
||||||
return openai_params
|
return openai_params
|
||||||
|
|
||||||
|
|
||||||
def _initialize_openai_v1(params: ProxyModelParameters):
|
def _initialize_openai_v1(params: ProxyModelParameters):
|
||||||
try:
|
try:
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
@@ -92,17 +93,12 @@ def _initialize_openai_v1(params: ProxyModelParameters):
|
|||||||
return openai_params, api_type, api_version
|
return openai_params, api_type, api_version
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _build_request(model: ProxyModel, params):
|
def _build_request(model: ProxyModel, params):
|
||||||
history = []
|
history = []
|
||||||
|
|
||||||
model_params = model.get_params()
|
model_params = model.get_params()
|
||||||
logger.info(f"Model: {model}, model_params: {model_params}")
|
logger.info(f"Model: {model}, model_params: {model_params}")
|
||||||
|
|
||||||
|
|
||||||
messages: List[ModelMessage] = params["messages"]
|
messages: List[ModelMessage] = params["messages"]
|
||||||
# Add history conversation
|
# Add history conversation
|
||||||
for message in messages:
|
for message in messages:
|
||||||
@@ -162,13 +158,17 @@ def chatgpt_generate_stream(
|
|||||||
history, payloads = _build_request(model, params)
|
history, payloads = _build_request(model, params)
|
||||||
if api_type == "azure":
|
if api_type == "azure":
|
||||||
from openai import AzureOpenAI
|
from openai import AzureOpenAI
|
||||||
|
|
||||||
client = AzureOpenAI(
|
client = AzureOpenAI(
|
||||||
api_key=openai_params["api_key"],
|
api_key=openai_params["api_key"],
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
azure_endpoint = openai_params["base_url"] # Your Azure OpenAI resource's endpoint value.
|
azure_endpoint=openai_params[
|
||||||
|
"base_url"
|
||||||
|
], # Your Azure OpenAI resource's endpoint value.
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
client = OpenAI(**openai_params)
|
client = OpenAI(**openai_params)
|
||||||
print("openai_params", openai_params)
|
print("openai_params", openai_params)
|
||||||
print("payloads", payloads)
|
print("payloads", payloads)
|
||||||
@@ -182,8 +182,8 @@ def chatgpt_generate_stream(
|
|||||||
yield text
|
yield text
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
history, payloads = _build_request(model, params)
|
history, payloads = _build_request(model, params)
|
||||||
|
|
||||||
res = openai.ChatCompletion.create(messages=history, **payloads)
|
res = openai.ChatCompletion.create(messages=history, **payloads)
|
||||||
@@ -206,18 +206,25 @@ async def async_chatgpt_generate_stream(
|
|||||||
history, payloads = _build_request(model, params)
|
history, payloads = _build_request(model, params)
|
||||||
if api_type == "azure":
|
if api_type == "azure":
|
||||||
from openai import AsyncAzureOpenAI
|
from openai import AsyncAzureOpenAI
|
||||||
|
|
||||||
client = AsyncAzureOpenAI(
|
client = AsyncAzureOpenAI(
|
||||||
api_key=openai_params["api_key"],
|
api_key=openai_params["api_key"],
|
||||||
end_point=openai_params["base_url"],
|
end_point=openai_params["base_url"],
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") # Your Azure OpenAI resource's endpoint value.
|
azure_endpoint=os.getenv(
|
||||||
|
"AZURE_OPENAI_ENDPOINT"
|
||||||
|
), # Your Azure OpenAI resource's endpoint value.
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
client = AsyncOpenAI(**openai_params)
|
client = AsyncOpenAI(**openai_params)
|
||||||
res = await client.chat.completions.create(messages=history, **payloads).model_dump()
|
res = await client.chat.completions.create(
|
||||||
|
messages=history, **payloads
|
||||||
|
).model_dump()
|
||||||
else:
|
else:
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
history, payloads = _build_request(model, params)
|
history, payloads = _build_request(model, params)
|
||||||
|
|
||||||
res = await openai.ChatCompletion.acreate(messages=history, **payloads)
|
res = await openai.ChatCompletion.acreate(messages=history, **payloads)
|
||||||
|
Reference in New Issue
Block a user