Update chatgpt.py

This commit is contained in:
luchun
2023-11-16 22:39:18 +08:00
committed by GitHub
parent e198cd34ae
commit 1b615aedd2

View File

@@ -56,6 +56,7 @@ def _initialize_openai(params: ProxyModelParameters):
return openai_params return openai_params
def _initialize_openai_v1(params: ProxyModelParameters): def _initialize_openai_v1(params: ProxyModelParameters):
try: try:
from openai import OpenAI from openai import OpenAI
@@ -92,17 +93,12 @@ def _initialize_openai_v1(params: ProxyModelParameters):
return openai_params, api_type, api_version return openai_params, api_type, api_version
def _build_request(model: ProxyModel, params): def _build_request(model: ProxyModel, params):
history = [] history = []
model_params = model.get_params() model_params = model.get_params()
logger.info(f"Model: {model}, model_params: {model_params}") logger.info(f"Model: {model}, model_params: {model_params}")
messages: List[ModelMessage] = params["messages"] messages: List[ModelMessage] = params["messages"]
# Add history conversation # Add history conversation
for message in messages: for message in messages:
@@ -162,13 +158,17 @@ def chatgpt_generate_stream(
history, payloads = _build_request(model, params) history, payloads = _build_request(model, params)
if api_type == "azure": if api_type == "azure":
from openai import AzureOpenAI from openai import AzureOpenAI
client = AzureOpenAI( client = AzureOpenAI(
api_key=openai_params["api_key"], api_key=openai_params["api_key"],
api_version=api_version, api_version=api_version,
azure_endpoint = openai_params["base_url"] # Your Azure OpenAI resource's endpoint value. azure_endpoint=openai_params[
"base_url"
], # Your Azure OpenAI resource's endpoint value.
) )
else: else:
from openai import OpenAI from openai import OpenAI
client = OpenAI(**openai_params) client = OpenAI(**openai_params)
print("openai_params", openai_params) print("openai_params", openai_params)
print("payloads", payloads) print("payloads", payloads)
@@ -182,8 +182,8 @@ def chatgpt_generate_stream(
yield text yield text
else: else:
import openai import openai
history, payloads = _build_request(model, params) history, payloads = _build_request(model, params)
res = openai.ChatCompletion.create(messages=history, **payloads) res = openai.ChatCompletion.create(messages=history, **payloads)
@@ -206,18 +206,25 @@ async def async_chatgpt_generate_stream(
history, payloads = _build_request(model, params) history, payloads = _build_request(model, params)
if api_type == "azure": if api_type == "azure":
from openai import AsyncAzureOpenAI from openai import AsyncAzureOpenAI
client = AsyncAzureOpenAI( client = AsyncAzureOpenAI(
api_key=openai_params["api_key"], api_key=openai_params["api_key"],
end_point=openai_params["base_url"], end_point=openai_params["base_url"],
api_version=api_version, api_version=api_version,
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") # Your Azure OpenAI resource's endpoint value. azure_endpoint=os.getenv(
"AZURE_OPENAI_ENDPOINT"
), # Your Azure OpenAI resource's endpoint value.
) )
else: else:
from openai import AsyncOpenAI from openai import AsyncOpenAI
client = AsyncOpenAI(**openai_params) client = AsyncOpenAI(**openai_params)
res = await client.chat.completions.create(messages=history, **payloads).model_dump() res = await client.chat.completions.create(
messages=history, **payloads
).model_dump()
else: else:
import openai import openai
history, payloads = _build_request(model, params) history, payloads = _build_request(model, params)
res = await openai.ChatCompletion.acreate(messages=history, **payloads) res = await openai.ChatCompletion.acreate(messages=history, **payloads)