Mirror of https://github.com/csunny/DB-GPT.git
fix(model): Fix openai adapt previous proxy_server_url configuration and support azure openai model (#668)
diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py
@@ -282,9 +282,30 @@ class ProxyModelParameters(BaseModelParameters):
             "help": "Proxy server url, such as: https://api.openai.com/v1/chat/completions"
         },
     )
 
     proxy_api_key: str = field(
         metadata={"tags": "privacy", "help": "The api key of current proxy LLM"},
     )
+
+    proxy_api_base: str = field(
+        default=None,
+        metadata={
+            "help": "The base api address, such as: https://api.openai.com/v1. If None, it is derived from proxy_server_url"
+        },
+    )
+
+    proxy_api_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The api type of the current proxy model; if you use Azure, set it to: azure"
+        },
+    )
+
+    proxy_api_version: Optional[str] = field(
+        default=None,
+        metadata={"help": "The api version of the current proxy model"},
+    )
+
     http_proxy: Optional[str] = field(
         default=os.environ.get("http_proxy") or os.environ.get("https_proxy"),
         metadata={"help": "The http or https proxy to use openai"},
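
The three new fields make the proxy endpoint explicit instead of inferring everything from proxy_server_url. A minimal configuration sketch for an Azure deployment follows; every value is illustrative, and the base fields (model_name, model_path) are assumed from BaseModelParameters rather than shown in this diff:

    from pilot.model.parameter import ProxyModelParameters

    # Illustrative values only; nothing below is taken from this commit.
    params = ProxyModelParameters(
        model_name="chatgpt_proxyllm",      # assumed BaseModelParameters field
        model_path="chatgpt_proxyllm",      # assumed BaseModelParameters field
        proxy_api_key="<azure-openai-key>",
        proxy_api_base="https://<resource>.openai.azure.com",
        proxy_api_type="azure",
        proxy_api_version="2023-05-15",     # an assumed API version string
        proxyllm_backend="gpt-35-turbo",    # Azure deployment name
    )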
diff --git a/pilot/model/proxy/llms/chatgpt.py b/pilot/model/proxy/llms/chatgpt.py
@@ -1,31 +1,63 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-import json
 import os
 from typing import List
+import logging
 
 import openai
 
 from pilot.model.proxy.llms.proxy_model import ProxyModel
+from pilot.model.parameter import ProxyModelParameters
 from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 
+logger = logging.getLogger(__name__)
+
-def chatgpt_generate_stream(
-    model: ProxyModel, tokenizer, params, device, context_len=2048
-):
+def _initialize_openai(params: ProxyModelParameters):
+    api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
+
+    api_base = params.proxy_api_base or os.getenv(
+        "OPENAI_API_BASE",
+        os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
+    )
+    api_key = params.proxy_api_key or os.getenv(
+        "OPENAI_API_KEY",
+        os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
+    )
+    api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")
+
+    if not api_base and params.proxy_server_url:
+        # Adapt the previous proxy_server_url configuration
+        api_base = params.proxy_server_url.split("/chat/completions")[0]
+    if api_type:
+        openai.api_type = api_type
+    if api_base:
+        openai.api_base = api_base
+    if api_key:
+        openai.api_key = api_key
+    if api_version:
+        openai.api_version = api_version
+    if params.http_proxy:
+        openai.proxy = params.http_proxy
+
+    openai_params = {
+        "api_type": api_type,
+        "api_base": api_base,
+        "api_version": api_version,
+        "proxy": params.http_proxy,
+    }
+
+    return openai_params
+
+
+def _build_request(model: ProxyModel, params):
     history = []
 
     model_params = model.get_params()
-    print(f"Model: {model}, model_params: {model_params}")
+    logger.info(f"Model: {model}, model_params: {model_params}")
 
-    proxy_api_key = model_params.proxy_api_key
-    if model_params.http_proxy:
-        openai.proxy = model_params.http_proxy
-    openai.api_key = os.getenv("OPENAI_API_KEY") or proxy_api_key
-    proxyllm_backend = model_params.proxyllm_backend
-    if not proxyllm_backend:
-        proxyllm_backend = "gpt-3.5-turbo"
+    openai_params = _initialize_openai(model_params)
 
     messages: List[ModelMessage] = params["messages"]
     # Add history conversation
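
_initialize_openai resolves each setting from the explicit parameter first and falls back to environment variables (the OPENAI_* names, with the AZURE_* variants when api_type is azure). For backward compatibility it also recovers the base address from the legacy proxy_server_url value, which pointed at the full chat-completions endpoint. A small sketch of that adaptation, with an example value:

    legacy_url = "https://api.openai.com/v1/chat/completions"  # example value
    # Stripping the trailing endpoint path recovers the api base.
    api_base = legacy_url.split("/chat/completions")[0]
    assert api_base == "https://api.openai.com/v1"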
@@ -51,14 +83,32 @@ def chatgpt_generate_stream(
     history.append(last_user_input)
 
     payloads = {
-        "model": proxyllm_backend,  # just for test, remove this later
         "temperature": params.get("temperature"),
         "max_tokens": params.get("max_new_tokens"),
         "stream": True,
     }
-    res = openai.ChatCompletion.create(messages=history, **payloads)
+    proxyllm_backend = model_params.proxyllm_backend
 
-    print(f"Send request to real model {proxyllm_backend}")
+    if openai_params["api_type"] == "azure":
+        # For Azure, engine is the deployment name
+        proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
+        payloads["engine"] = proxyllm_backend
+    else:
+        proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
+        payloads["model"] = proxyllm_backend
+
+    logger.info(
+        f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
+    )
+    return history, payloads
+
+
+def chatgpt_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
+    history, payloads = _build_request(model, params)
+
+    res = openai.ChatCompletion.create(messages=history, **payloads)
+
     text = ""
     for r in res:
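
_build_request chooses the payload key per provider: Azure OpenAI addresses a deployment through engine, while the standard API uses model. With illustrative sampling values, the two payload shapes are:

    # api_type == "azure" (illustrative values)
    payloads_azure = {"temperature": 0.7, "max_tokens": 512, "stream": True, "engine": "gpt-35-turbo"}
    # any other api_type
    payloads_openai = {"temperature": 0.7, "max_tokens": 512, "stream": True, "model": "gpt-3.5-turbo"}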
@@ -66,3 +116,18 @@ def chatgpt_generate_stream(
             content = r["choices"][0]["delta"]["content"]
             text += content
             yield text
+
+
+async def async_chatgpt_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
+    history, payloads = _build_request(model, params)
+
+    res = await openai.ChatCompletion.acreate(messages=history, **payloads)
+
+    text = ""
+    async for r in res:
+        if r["choices"][0]["delta"].get("content") is not None:
+            content = r["choices"][0]["delta"]["content"]
+            text += content
+            yield text
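
Both generators yield the accumulated response after every streamed delta, so a caller only has to print the latest value. A minimal driver sketch follows; the ProxyModel instance and the ModelMessage list are assumed rather than shown in this commit, and the module path comes from the diff header above:

    import asyncio

    from pilot.model.proxy.llms.chatgpt import async_chatgpt_generate_stream

    async def main(model):
        # tokenizer and device are unused by this proxy implementation, so None is fine.
        params = {
            "messages": [],  # fill with ModelMessage user/assistant turns
            "temperature": 0.7,
            "max_new_tokens": 512,
        }
        async for text in async_chatgpt_generate_stream(model, None, params, None):
            print(text)  # cumulative text so far

    # asyncio.run(main(model))  # model: a configured ProxyModel instance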