mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 16:18:27 +00:00
fix: use openai package instead of requests
This commit is contained in:
parent
9979b6aa80
commit
1c56fca566
@ -2,10 +2,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import requests
|
||||
import os
|
||||
from typing import List
|
||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
|
||||
import openai
|
||||
|
||||
from pilot.model.proxy.llms.proxy_model import ProxyModel
|
||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
|
||||
|
||||
def chatgpt_generate_stream(
|
||||
@ -17,16 +20,11 @@ def chatgpt_generate_stream(
|
||||
print(f"Model: {model}, model_params: {model_params}")
|
||||
|
||||
proxy_api_key = model_params.proxy_api_key
|
||||
proxy_server_url = model_params.proxy_server_url
|
||||
openai.api_key = openai_key = os.getenv("OPENAI_API_KEY") or proxy_api_key
|
||||
proxyllm_backend = model_params.proxyllm_backend
|
||||
if not proxyllm_backend:
|
||||
proxyllm_backend = "gpt-3.5-turbo"
|
||||
|
||||
headers = {
|
||||
"Authorization": "Bearer " + proxy_api_key,
|
||||
"Token": proxy_api_key,
|
||||
}
|
||||
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
# Add history conversation
|
||||
for message in messages:
|
||||
@ -52,28 +50,17 @@ def chatgpt_generate_stream(
|
||||
|
||||
payloads = {
|
||||
"model": proxyllm_backend, # just for test, remove this later
|
||||
"messages": history,
|
||||
"temperature": params.get("temperature"),
|
||||
"max_tokens": params.get("max_new_tokens"),
|
||||
"stream": True,
|
||||
}
|
||||
res = openai.ChatCompletion.create(messages=history, **payloads)
|
||||
|
||||
res = requests.post(proxy_server_url, headers=headers, json=payloads, stream=True)
|
||||
|
||||
print(f"Send request to {proxy_server_url} with real model {proxyllm_backend}")
|
||||
print(f"Send request to real model {proxyllm_backend}")
|
||||
|
||||
text = ""
|
||||
for line in res.iter_lines():
|
||||
if line:
|
||||
if not line.startswith(b"data: "):
|
||||
error_message = line.decode("utf-8")
|
||||
yield error_message
|
||||
else:
|
||||
json_data = line.split(b": ", 1)[1]
|
||||
decoded_line = json_data.decode("utf-8")
|
||||
if decoded_line.lower() != "[DONE]".lower():
|
||||
obj = json.loads(json_data)
|
||||
if obj["choices"][0]["delta"].get("content") is not None:
|
||||
content = obj["choices"][0]["delta"]["content"]
|
||||
text += content
|
||||
yield text
|
||||
for r in res:
|
||||
if r["choices"][0]["delta"].get("content") is not None:
|
||||
content = r["choices"][0]["delta"]["content"]
|
||||
text += content
|
||||
yield text
|
||||
|
Loading…
Reference in New Issue
Block a user