diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index 7b198c49a..24bee6cdb 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -26,6 +26,9 @@ from pilot.openapi.api_view_model import (
     ConversationVo,
     MessageVo,
     ChatSceneVo,
+    ChatCompletionResponseStreamChoice,
+    DeltaMessage,
+    ChatCompletionStreamResponse,
 )
 from pilot.connections.db_conn_info import DBConfig, DbTypeInfo
 from pilot.configs.config import Config
@@ -383,7 +386,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
         )
     else:
         return StreamingResponse(
-            stream_generator(chat),
+            stream_generator(chat, dialogue.incremental, dialogue.model_name),
             headers=headers,
             media_type="text/plain",
         )
@@ -421,19 +424,48 @@ async def no_stream_generator(chat):
     yield f"data: {msg}\n\n"
 
 
-async def stream_generator(chat):
+async def stream_generator(chat, incremental: bool, model_name: str):
+    """Generate streaming responses.
+
+    Our goal is to generate OpenAI-compatible streaming responses.
+    Currently, only the incremental response is compatible; the full response will be adapted in the future.
+
+    Args:
+        chat (BaseChat): Chat instance.
+        incremental (bool): Controls whether the content is returned incrementally or in full each time.
+        model_name (str): The model name.
+
+    Yields:
+        str: Server-sent event (SSE) data lines.
+    """
     msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."
+    stream_id = f"chatcmpl-{str(uuid.uuid1())}"
+    previous_response = ""
     async for chunk in chat.stream_call():
         if chunk:
             msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
                 chunk, chat.skip_echo_len
             )
-
-            msg = msg.replace("\n", "\\n")
-            yield f"data:{msg}\n\n"
+            msg = msg.replace("\ufffd", "")
+            if incremental:
+                incremental_output = msg[len(previous_response) :]
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=DeltaMessage(role="assistant", content=incremental_output),
+                )
+                chunk = ChatCompletionStreamResponse(
+                    id=stream_id, choices=[choice_data], model=model_name
+                )
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            else:
+                # TODO generate an OpenAI-compatible streaming response for the full mode as well
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+            previous_response = msg
             await asyncio.sleep(0.02)
-
+    if incremental:
+        yield "data: [DONE]\n\n"
     chat.current_message.add_ai_message(msg)
     chat.current_message.add_view_message(msg)
     chat.memory.append(chat.current_message)
diff --git a/pilot/openapi/api_view_model.py b/pilot/openapi/api_view_model.py
index d03beec8d..60065f2f2 100644
--- a/pilot/openapi/api_view_model.py
+++ b/pilot/openapi/api_view_model.py
@@ -1,5 +1,7 @@
 from pydantic import BaseModel, Field
-from typing import TypeVar, Generic, Any
+from typing import TypeVar, Generic, Any, Optional, Literal, List
+import uuid
+import time
 
 T = TypeVar("T")
 
@@ -59,6 +61,11 @@ class ConversationVo(BaseModel):
     """
     model_name: str = None
 
+    """Controls whether the content is returned incrementally or in full each time.
+    If this parameter is not provided, the full content is returned by default.
+    """
+    incremental: bool = False
+
 
 class MessageVo(BaseModel):
     """
@@ -83,3 +90,21 @@ class MessageVo(BaseModel):
     model_name
     """
     model_name: str
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]