Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-16 22:51:24 +00:00
feat: Run AWEL flow in CLI (#1341)
@@ -372,7 +372,7 @@ async def chat_completions(
             incremental=dialogue.incremental,
         )
         return StreamingResponse(
-            flow_service.chat_flow(dialogue.select_param, flow_req),
+            flow_service.chat_stream_flow_str(dialogue.select_param, flow_req),
             headers=headers,
             media_type="text/event-stream",
         )
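In this v1 hunk the endpoint switches from `flow_service.chat_flow(...)` to `flow_service.chat_stream_flow_str(...)`, i.e. to a generator that yields ready-made SSE strings which `StreamingResponse` can relay verbatim. A minimal sketch of that pattern in plain FastAPI; `fake_flow_output` is a hypothetical stand-in, not DB-GPT's actual service code:

```python
from typing import AsyncIterator

from fastapi import FastAPI
from starlette.responses import StreamingResponse

app = FastAPI()


async def fake_flow_output() -> AsyncIterator[str]:
    # Hypothetical stand-in for chat_stream_flow_str: every item is already a
    # complete "data: ...\n\n" SSE frame, so the route can forward it as-is.
    for token in ["Hello", ", ", "world"]:
        yield f"data: {token}\n\n"
    yield "data: [DONE]\n\n"


@app.get("/demo/stream")
async def demo_stream():
    return StreamingResponse(fake_flow_output(), media_type="text/event-stream")
```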
@@ -2,11 +2,11 @@ import json
 import re
 import time
 import uuid
-from typing import Optional
+from typing import AsyncIterator, Optional

 from fastapi import APIRouter, Body, Depends, HTTPException
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
-from starlette.responses import StreamingResponse
+from starlette.responses import JSONResponse, StreamingResponse

 from dbgpt.app.openapi.api_v1.api_v1 import (
     CHAT_FACTORY,
@@ -27,6 +27,7 @@ from dbgpt.core.schema.api import (
     ChatCompletionStreamResponse,
     ChatMessage,
     DeltaMessage,
+    ErrorResponse,
     UsageInfo,
 )
 from dbgpt.model.cluster.apiserver.api import APISettings
@@ -114,11 +115,14 @@ async def chat_completions(
             media_type="text/event-stream",
         )
     elif request.chat_mode == ChatMode.CHAT_AWEL_FLOW.value:
-        return StreamingResponse(
-            chat_flow_stream_wrapper(request),
-            headers=headers,
-            media_type="text/event-stream",
-        )
+        if not request.stream:
+            return await chat_flow_wrapper(request)
+        else:
+            return StreamingResponse(
+                chat_flow_stream_wrapper(request),
+                headers=headers,
+                media_type="text/event-stream",
+            )
     elif (
         request.chat_mode is None
         or request.chat_mode == ChatMode.CHAT_NORMAL.value
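With this change, `chat_mode == CHAT_AWEL_FLOW` now honors `request.stream`: a non-streaming request gets a single OpenAI-style JSON body from `chat_flow_wrapper` (defined below), while a streaming request keeps the SSE path. A hedged client sketch; the URL, port, model name, and the `chat_mode` string are assumptions, only the field names come from this diff:

```python
import requests

# Assumed local DB-GPT endpoint; adjust host/port to your deployment.
URL = "http://127.0.0.1:5670/api/v2/chat/completions"

body = {
    "model": "chatgpt_proxyllm",      # hypothetical model name
    "chat_mode": "chat_flow",         # assumed value of ChatMode.CHAT_AWEL_FLOW
    "chat_param": "<your-flow-uid>",  # AWEL flow UID, passed as chat_param in the diff
    "messages": "What is AWEL?",
    "stream": False,                  # exercises the new non-streaming branch
}

resp = requests.post(URL, json=body, timeout=60)
print(resp.json())  # OpenAI-style completion body, or an ErrorResponse on failure
```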
@@ -244,35 +248,43 @@ async def chat_app_stream_wrapper(request: ChatCompletionRequestBody = None):
     yield "data: [DONE]\n\n"


+async def chat_flow_wrapper(request: ChatCompletionRequestBody):
+    flow_service = get_chat_flow()
+    flow_req = CommonLLMHttpRequestBody(**request.dict())
+    flow_uid = request.chat_param
+    output = await flow_service.safe_chat_flow(flow_uid, flow_req)
+    if not output.success:
+        return JSONResponse(
+            ErrorResponse(message=output.text, code=output.error_code).dict(),
+            status_code=400,
+        )
+    else:
+        choice_data = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role="assistant", content=output.text),
+        )
+        if output.usage:
+            usage = UsageInfo(**output.usage)
+        else:
+            usage = UsageInfo()
+        return ChatCompletionResponse(
+            id=request.conv_uid, choices=[choice_data], model=request.model, usage=usage
+        )
+
+
 async def chat_flow_stream_wrapper(
-    request: ChatCompletionRequestBody = None,
-):
+    request: ChatCompletionRequestBody,
+) -> AsyncIterator[str]:
     """chat app stream
     Args:
         request (OpenAPIChatCompletionRequest): request
         token (APIToken): token
     """
     flow_service = get_chat_flow()
     flow_req = CommonLLMHttpRequestBody(**request.dict())
-    async for output in flow_service.chat_flow(request.chat_param, flow_req):
-        if output.startswith("data: [DONE]"):
-            yield output
-        if output.startswith("data:"):
-            output = output[len("data: ") :]
-            choice_data = ChatCompletionResponseStreamChoice(
-                index=0,
-                delta=DeltaMessage(role="assistant", content=output),
-            )
-            chunk = ChatCompletionStreamResponse(
-                id=request.conv_uid,
-                choices=[choice_data],
-                model=request.model,
-                created=int(time.time()),
-            )
-            chat_completion_response = (
-                f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
-            )
-            yield chat_completion_response
+    flow_uid = request.chat_param
+    async for output in flow_service.chat_stream_openai(flow_uid, flow_req):
+        yield output


 def check_chat_request(request: ChatCompletionRequestBody = Body()):
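After this refactor the stream wrapper no longer assembles `ChatCompletionStreamResponse` chunks itself; that formatting moved behind `flow_service.chat_stream_openai`, and the wrapper simply relays the pre-formatted `data: ...` frames. A hedged sketch of consuming that stream from a client; the endpoint and payload values are assumptions, as above:

```python
import json

import requests

resp = requests.post(
    "http://127.0.0.1:5670/api/v2/chat/completions",  # assumed endpoint
    json={
        "model": "chatgpt_proxyllm",   # hypothetical model name
        "chat_mode": "chat_flow",      # assumed value of ChatMode.CHAT_AWEL_FLOW
        "chat_param": "<your-flow-uid>",
        "messages": "Hello",
        "stream": True,
    },
    stream=True,
    timeout=60,
)

for raw in resp.iter_lines(decode_unicode=True):
    if not raw or not raw.startswith("data: "):
        continue
    payload = raw[len("data: "):]
    if payload == "[DONE]":  # end-of-stream sentinel, as in the diff
        break
    chunk = json.loads(payload)  # OpenAI-style stream chunk
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
```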