Update chatgpt.py

This commit is contained in:
luchun
2023-11-16 22:39:18 +08:00
committed by GitHub
parent e198cd34ae
commit 1b615aedd2

View File

@@ -56,6 +56,7 @@ def _initialize_openai(params: ProxyModelParameters):
return openai_params return openai_params
def _initialize_openai_v1(params: ProxyModelParameters): def _initialize_openai_v1(params: ProxyModelParameters):
try: try:
from openai import OpenAI from openai import OpenAI
@@ -89,11 +90,7 @@ def _initialize_openai_v1(params: ProxyModelParameters):
"proxies": params.http_proxy, "proxies": params.http_proxy,
} }
return openai_params,api_type,api_version return openai_params, api_type, api_version
def _build_request(model: ProxyModel, params): def _build_request(model: ProxyModel, params):
@@ -102,7 +99,6 @@ def _build_request(model: ProxyModel, params):
model_params = model.get_params() model_params = model.get_params()
logger.info(f"Model: {model}, model_params: {model_params}") logger.info(f"Model: {model}, model_params: {model_params}")
messages: List[ModelMessage] = params["messages"] messages: List[ModelMessage] = params["messages"]
# Add history conversation # Add history conversation
for message in messages: for message in messages:
@@ -134,7 +130,7 @@ def _build_request(model: ProxyModel, params):
proxyllm_backend = model_params.proxyllm_backend proxyllm_backend = model_params.proxyllm_backend
if metadata.version("openai") >= "1.0.0": if metadata.version("openai") >= "1.0.0":
openai_params,api_type,api_version = _initialize_openai_v1(model_params) openai_params, api_type, api_version = _initialize_openai_v1(model_params)
proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo" proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
payloads["model"] = proxyllm_backend payloads["model"] = proxyllm_backend
else: else:
@@ -158,20 +154,24 @@ def chatgpt_generate_stream(
): ):
if metadata.version("openai") >= "1.0.0": if metadata.version("openai") >= "1.0.0":
model_params = model.get_params() model_params = model.get_params()
openai_params,api_type,api_version = _initialize_openai_v1(model_params) openai_params, api_type, api_version = _initialize_openai_v1(model_params)
history, payloads = _build_request(model, params) history, payloads = _build_request(model, params)
if api_type == "azure": if api_type == "azure":
from openai import AzureOpenAI from openai import AzureOpenAI
client = AzureOpenAI( client = AzureOpenAI(
api_key = openai_params["api_key"], api_key=openai_params["api_key"],
api_version = api_version, api_version=api_version,
azure_endpoint = openai_params["base_url"] # Your Azure OpenAI resource's endpoint value. azure_endpoint=openai_params[
"base_url"
], # Your Azure OpenAI resource's endpoint value.
) )
else: else:
from openai import OpenAI from openai import OpenAI
client = OpenAI(**openai_params) client = OpenAI(**openai_params)
print("openai_params",openai_params) print("openai_params", openai_params)
print("payloads",payloads) print("payloads", payloads)
res = client.chat.completions.create(messages=history, **payloads) res = client.chat.completions.create(messages=history, **payloads)
print(res) print(res)
text = "" text = ""
@@ -182,14 +182,14 @@ def chatgpt_generate_stream(
yield text yield text
else: else:
import openai import openai
history, payloads = _build_request(model, params) history, payloads = _build_request(model, params)
res = openai.ChatCompletion.create(messages=history, **payloads) res = openai.ChatCompletion.create(messages=history, **payloads)
text = "" text = ""
print("res",res) print("res", res)
for r in res: for r in res:
if r["choices"][0]["delta"].get("content") is not None: if r["choices"][0]["delta"].get("content") is not None:
content = r["choices"][0]["delta"]["content"] content = r["choices"][0]["delta"]["content"]
@@ -202,22 +202,29 @@ async def async_chatgpt_generate_stream(
): ):
if metadata.version("openai") >= "1.0.0": if metadata.version("openai") >= "1.0.0":
model_params = model.get_params() model_params = model.get_params()
openai_params,api_type,api_version = _initialize_openai_v1(model_params) openai_params, api_type, api_version = _initialize_openai_v1(model_params)
history, payloads = _build_request(model, params) history, payloads = _build_request(model, params)
if api_type == "azure": if api_type == "azure":
from openai import AsyncAzureOpenAI from openai import AsyncAzureOpenAI
client = AsyncAzureOpenAI( client = AsyncAzureOpenAI(
api_key = openai_params["api_key"], api_key=openai_params["api_key"],
end_point = openai_params["base_url"], end_point=openai_params["base_url"],
api_version = api_version, api_version=api_version,
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") # Your Azure OpenAI resource's endpoint value. azure_endpoint=os.getenv(
"AZURE_OPENAI_ENDPOINT"
), # Your Azure OpenAI resource's endpoint value.
) )
else: else:
from openai import AsyncOpenAI from openai import AsyncOpenAI
client = AsyncOpenAI(**openai_params) client = AsyncOpenAI(**openai_params)
res = await client.chat.completions.create(messages=history, **payloads).model_dump() res = await client.chat.completions.create(
messages=history, **payloads
).model_dump()
else: else:
import openai import openai
history, payloads = _build_request(model, params) history, payloads = _build_request(model, params)
res = await openai.ChatCompletion.acreate(messages=history, **payloads) res = await openai.ChatCompletion.acreate(messages=history, **payloads)