feat(web): Add incremental response to streaming response for /v1/chat/completion request
Repository: https://github.com/csunny/DB-GPT.git
Commit: 461179ee6f (parent: 896af4e16f)
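In client terms: when a request sets the new `incremental` flag, the server streams OpenAI-style JSON chunks (one `data: ...` line per delta) and terminates with `data: [DONE]`. A minimal client sketch follows; the host, port, route, and any request field other than `model_name`/`incremental` are assumptions for illustration, not taken from this commit.

```python
# Hypothetical client for the new incremental streaming mode.
import json
import requests

resp = requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",  # assumed local deployment/route
    json={"user_input": "hello", "model_name": "vicuna-13b", "incremental": True},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":  # terminator emitted when incremental=True
        break
    chunk = json.loads(payload)
    # Each frame carries only the newly generated text in delta.content.
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
```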
The first changed file is the API handler module (it defines chat_completions and stream_generator, per the hunk context):

@@ -26,6 +26,9 @@ from pilot.openapi.api_view_model import (
     ConversationVo,
     MessageVo,
     ChatSceneVo,
+    ChatCompletionResponseStreamChoice,
+    DeltaMessage,
+    ChatCompletionStreamResponse,
 )
 from pilot.connections.db_conn_info import DBConfig, DbTypeInfo
 from pilot.configs.config import Config
@@ -383,7 +386,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
         )
     else:
         return StreamingResponse(
-            stream_generator(chat),
+            stream_generator(chat, dialogue.incremental, dialogue.model_name),
             headers=headers,
             media_type="text/plain",
         )
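For readers unfamiliar with the pattern, here is a self-contained sketch of how Starlette/FastAPI's `StreamingResponse` drives an async generator like `stream_generator`. The app, route, and generator below are illustrative stand-ins, not DB-GPT's actual code.

```python
# Minimal StreamingResponse wiring (FastAPI/Starlette pattern).
import asyncio
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def fake_stream_generator(incremental: bool, model_name: str):
    # Stand-in for stream_generator: yields SSE-style lines.
    for token in ["Hel", "lo", "!"]:
        yield f"data: {token}\n\n"
        await asyncio.sleep(0.02)
    if incremental:
        yield "data: [DONE]\n\n"

@app.post("/v1/chat/completions")  # route name assumed from the commit title
async def chat_completions(incremental: bool = False, model_name: str = "vicuna-13b"):
    return StreamingResponse(
        fake_stream_generator(incremental, model_name),
        media_type="text/plain",  # matches the original handler
    )
```

(The real handler also threads custom response headers through, as the hunk shows.)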
@@ -421,19 +424,48 @@ async def no_stream_generator(chat):
         yield f"data: {msg}\n\n"


-async def stream_generator(chat):
+async def stream_generator(chat, incremental: bool, model_name: str):
+    """Generate streaming responses.
+
+    Our goal is to generate OpenAI-compatible streaming responses.
+    Currently the incremental response is compatible; the full response will
+    be transformed in the future.
+
+    Args:
+        chat (BaseChat): Chat instance.
+        incremental (bool): Whether the content is returned incrementally or
+            in full each time.
+        model_name (str): The model name.
+
+    Yields:
+        str: Streaming responses.
+    """
     msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."
+
+    stream_id = f"chatcmpl-{str(uuid.uuid1())}"
+    previous_response = ""
     async for chunk in chat.stream_call():
         if chunk:
             msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
                 chunk, chat.skip_echo_len
             )
+            msg = msg.replace("\ufffd", "")
+            if incremental:
+                incremental_output = msg[len(previous_response) :]
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=DeltaMessage(role="assistant", content=incremental_output),
+                )
+                chunk = ChatCompletionStreamResponse(
+                    id=stream_id, choices=[choice_data], model=model_name
+                )
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            else:
+                # TODO: generate an OpenAI-compatible streaming response here too
                 msg = msg.replace("\n", "\\n")
                 yield f"data:{msg}\n\n"
+            previous_response = msg
             await asyncio.sleep(0.02)
+    if incremental:
+        yield "data: [DONE]\n\n"
     chat.current_message.add_ai_message(msg)
     chat.current_message.add_view_message(msg)
     chat.memory.append(chat.current_message)
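Taken together: incremental mode emits OpenAI-style SSE frames, while the else branch keeps the legacy full-text `data:{msg}` frames. Since each `delta.content` is `msg[len(previous_response):]`, concatenating the deltas reconstructs the full answer. A stdlib-only sketch (frame contents are illustrative; note that `exclude_unset=True` omits `created` and `finish_reason`, which the server never sets explicitly):

```python
# Re-assembling the incremental frames into the full answer.
import json

frames = [
    'data: {"id": "chatcmpl-...", "model": "vicuna-13b",'
    ' "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hel"}}]}',
    'data: {"id": "chatcmpl-...", "model": "vicuna-13b",'
    ' "choices": [{"index": 0, "delta": {"role": "assistant", "content": "lo!"}}]}',
    "data: [DONE]",
]

full = ""
for frame in frames:
    payload = frame[len("data: "):]
    if payload == "[DONE]":
        break
    full += json.loads(payload)["choices"][0]["delta"]["content"]
assert full == "Hello!"
```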
The second changed file is pilot/openapi/api_view_model.py (the module named in the import hunk above); it gains the request flag and the OpenAI-style stream models:

@@ -1,5 +1,7 @@
 from pydantic import BaseModel, Field
-from typing import TypeVar, Generic, Any
+from typing import TypeVar, Generic, Any, Optional, Literal, List
+import uuid
+import time

 T = TypeVar("T")
@@ -59,6 +61,11 @@ class ConversationVo(BaseModel):
     """
     model_name: str = None

+    """Used to control whether the content is returned incrementally or in full each time.
+    If this parameter is not provided, the default is full return.
+    """
+    incremental: bool = False
+

 class MessageVo(BaseModel):
     """
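A quick sketch of the flag's parsing semantics (pydantic v1, as this module uses; `ConversationVoSketch` is a hypothetical stand-in so the snippet stays self-contained):

```python
from pydantic import BaseModel

class ConversationVoSketch(BaseModel):  # stand-in for ConversationVo
    model_name: str = None
    incremental: bool = False  # omitted in the request body -> full responses

vo = ConversationVoSketch.parse_obj({"model_name": "vicuna-13b"})
assert vo.incremental is False
vo2 = ConversationVoSketch.parse_obj({"model_name": "vicuna-13b", "incremental": True})
assert vo2.incremental is True
```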
@@ -83,3 +90,21 @@ class MessageVo(BaseModel):
         model_name
     """
     model_name: str
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
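Finally, a sketch of building one stream frame with the new models, mirroring what `stream_generator` does above (assumes this repo is importable at this commit; `.json(exclude_unset=True, ensure_ascii=False)` is pydantic v1 serialization):

```python
import uuid

from pilot.openapi.api_view_model import (
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    DeltaMessage,
)

choice = ChatCompletionResponseStreamChoice(
    index=0, delta=DeltaMessage(role="assistant", content="Hel")
)
frame = ChatCompletionStreamResponse(
    id=f"chatcmpl-{uuid.uuid1()}", model="vicuna-13b", choices=[choice]
)
print(frame.json(exclude_unset=True, ensure_ascii=False))
# -> {"id": "chatcmpl-...", "model": "vicuna-13b",
#     "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hel"}}]}
```

Because `created` and `finish_reason` are never explicitly set, `exclude_unset=True` drops them from the serialized frame; clients should treat those keys as optional.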