mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-04 10:00:17 +00:00)

fix: use openai package instead of requests

This commit is contained in:
parent 9979b6aa80
commit 1c56fca566
@@ -2,10 +2,13 @@
 # -*- coding: utf-8 -*-
 
 import json
-import requests
+import os
 from typing import List
-from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+
+import openai
 
 from pilot.model.proxy.llms.proxy_model import ProxyModel
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+
 
 def chatgpt_generate_stream(
@@ -17,16 +20,11 @@ def chatgpt_generate_stream(
     print(f"Model: {model}, model_params: {model_params}")
 
     proxy_api_key = model_params.proxy_api_key
-    proxy_server_url = model_params.proxy_server_url
+    openai.api_key = openai_key = os.getenv("OPENAI_API_KEY") or proxy_api_key
     proxyllm_backend = model_params.proxyllm_backend
     if not proxyllm_backend:
         proxyllm_backend = "gpt-3.5-turbo"
 
-    headers = {
-        "Authorization": "Bearer " + proxy_api_key,
-        "Token": proxy_api_key,
-    }
-
     messages: List[ModelMessage] = params["messages"]
     # Add history conversation
     for message in messages:
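Note on the replacement line above: the key is resolved environment-first, so OPENAI_API_KEY wins when set and the configured proxy_api_key is only a fallback. A minimal sketch of that or-fallback, with a hypothetical key value (in the commit itself, proxy_api_key comes from model_params and the result is assigned to openai.api_key):

import os

# Hypothetical value for illustration only.
proxy_api_key = "sk-from-config"
# The env var takes precedence when set and non-empty; otherwise fall back.
api_key = os.getenv("OPENAI_API_KEY") or proxy_api_key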
@@ -52,28 +50,17 @@ def chatgpt_generate_stream(
 
     payloads = {
         "model": proxyllm_backend,  # just for test, remove this later
-        "messages": history,
         "temperature": params.get("temperature"),
         "max_tokens": params.get("max_new_tokens"),
         "stream": True,
     }
+    res = openai.ChatCompletion.create(messages=history, **payloads)
 
-    res = requests.post(proxy_server_url, headers=headers, json=payloads, stream=True)
-
-    print(f"Send request to {proxy_server_url} with real model {proxyllm_backend}")
+    print(f"Send request to real model {proxyllm_backend}")
 
     text = ""
-    for line in res.iter_lines():
-        if line:
-            if not line.startswith(b"data: "):
-                error_message = line.decode("utf-8")
-                yield error_message
-            else:
-                json_data = line.split(b": ", 1)[1]
-                decoded_line = json_data.decode("utf-8")
-                if decoded_line.lower() != "[DONE]".lower():
-                    obj = json.loads(json_data)
-                    if obj["choices"][0]["delta"].get("content") is not None:
-                        content = obj["choices"][0]["delta"]["content"]
-                        text += content
-                        yield text
+    for r in res:
+        if r["choices"][0]["delta"].get("content") is not None:
+            content = r["choices"][0]["delta"]["content"]
+            text += content
+            yield text
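For context, the rewritten loop relies on how the pre-1.0 openai SDK streams: with stream=True, openai.ChatCompletion.create returns an iterator of chunks whose "delta" may carry a new content fragment, replacing the hand-rolled SSE parsing of the requests version. A standalone sketch of the same pattern, assuming OPENAI_API_KEY is set; the model name and prompt are placeholders, not part of the commit:

import os

import openai  # pre-1.0 SDK, as used in this commit

openai.api_key = os.getenv("OPENAI_API_KEY")

# stream=True makes create() return an iterator of incremental chunks
# instead of a single complete response.
res = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)

text = ""
for r in res:
    delta = r["choices"][0]["delta"]
    if delta.get("content") is not None:
        # Accumulate fragments, as chatgpt_generate_stream does before
        # yielding the running text.
        text += delta["content"]
print(text)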
|