feat: Run AWEL flow in CLI (#1341)

This commit is contained in:
Fangyin Cheng
2024-03-27 12:50:05 +08:00
committed by GitHub
parent 340a9fbc35
commit 3a7a2cbbb8
42 changed files with 1454 additions and 422 deletions

View File

@@ -372,7 +372,7 @@ async def chat_completions(
incremental=dialogue.incremental,
)
return StreamingResponse(
flow_service.chat_flow(dialogue.select_param, flow_req),
flow_service.chat_stream_flow_str(dialogue.select_param, flow_req),
headers=headers,
media_type="text/event-stream",
)

View File

@@ -2,11 +2,11 @@ import json
import re
import time
import uuid
from typing import Optional
from typing import AsyncIterator, Optional
from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette.responses import StreamingResponse
from starlette.responses import JSONResponse, StreamingResponse
from dbgpt.app.openapi.api_v1.api_v1 import (
CHAT_FACTORY,
@@ -27,6 +27,7 @@ from dbgpt.core.schema.api import (
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
ErrorResponse,
UsageInfo,
)
from dbgpt.model.cluster.apiserver.api import APISettings
@@ -114,11 +115,14 @@ async def chat_completions(
media_type="text/event-stream",
)
elif request.chat_mode == ChatMode.CHAT_AWEL_FLOW.value:
return StreamingResponse(
chat_flow_stream_wrapper(request),
headers=headers,
media_type="text/event-stream",
)
if not request.stream:
return await chat_flow_wrapper(request)
else:
return StreamingResponse(
chat_flow_stream_wrapper(request),
headers=headers,
media_type="text/event-stream",
)
elif (
request.chat_mode is None
or request.chat_mode == ChatMode.CHAT_NORMAL.value
@@ -244,35 +248,43 @@ async def chat_app_stream_wrapper(request: ChatCompletionRequestBody = None):
yield "data: [DONE]\n\n"
async def chat_flow_wrapper(request: ChatCompletionRequestBody):
flow_service = get_chat_flow()
flow_req = CommonLLMHttpRequestBody(**request.dict())
flow_uid = request.chat_param
output = await flow_service.safe_chat_flow(flow_uid, flow_req)
if not output.success:
return JSONResponse(
ErrorResponse(message=output.text, code=output.error_code).dict(),
status_code=400,
)
else:
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=output.text),
)
if output.usage:
usage = UsageInfo(**output.usage)
else:
usage = UsageInfo()
return ChatCompletionResponse(
id=request.conv_uid, choices=[choice_data], model=request.model, usage=usage
)
async def chat_flow_stream_wrapper(
request: ChatCompletionRequestBody = None,
):
request: ChatCompletionRequestBody,
) -> AsyncIterator[str]:
"""chat app stream
Args:
request (OpenAPIChatCompletionRequest): request
token (APIToken): token
"""
flow_service = get_chat_flow()
flow_req = CommonLLMHttpRequestBody(**request.dict())
async for output in flow_service.chat_flow(request.chat_param, flow_req):
if output.startswith("data: [DONE]"):
yield output
if output.startswith("data:"):
output = output[len("data: ") :]
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=output),
)
chunk = ChatCompletionStreamResponse(
id=request.conv_uid,
choices=[choice_data],
model=request.model,
created=int(time.time()),
)
chat_completion_response = (
f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
)
yield chat_completion_response
flow_uid = request.chat_param
async for output in flow_service.chat_stream_openai(flow_uid, flow_req):
yield output
def check_chat_request(request: ChatCompletionRequestBody = Body()):