Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-10-22 09:28:42 +00:00)
feat: Add dbgpt client and add api v2
345
dbgpt/app/openapi/api_v2.py
Normal file
@@ -0,0 +1,345 @@
import json
import re
import time
import uuid
from typing import Optional

from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastchat.protocol.api_protocol import (
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    DeltaMessage,
    UsageInfo,
)
from starlette.responses import StreamingResponse

from dbgpt.app.openapi.api_v1.api_v1 import (
    CHAT_FACTORY,
    __new_conversation,
    get_chat_flow,
    get_executor,
    stream_generator,
)
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.client.schemas import ChatCompletionRequestBody
from dbgpt.component import logger
from dbgpt.core.awel import CommonLLMHttpRequestBody, CommonLLMHTTPRequestContext
from dbgpt.model.cluster.apiserver.api import APISettings
from dbgpt.serve.agent.agents.controller import multi_agents
from dbgpt.serve.flow.api.endpoints import get_service
from dbgpt.serve.flow.service.service import Service as FlowService
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.tracer import SpanType, root_tracer

router = APIRouter()
api_settings = APISettings()
get_bearer_token = HTTPBearer(auto_error=False)


async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
    service=Depends(get_service),
) -> Optional[str]:
    """Check the API key.

    Args:
        auth (Optional[HTTPAuthorizationCredentials]): The bearer token.
        service (Service): The flow service.
    """
    if service.config.api_keys:
        api_keys = [key.strip() for key in service.config.api_keys.split(",")]
        if auth is None or (token := auth.credentials) not in api_keys:
            raise HTTPException(
                status_code=401,
                detail={
                    "error": {
                        "message": "",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_api_key",
                    }
                },
            )
        return token
    else:
        return None

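As a quick illustration of the key check above: when `service.config.api_keys` is set to a comma-separated list, a client must send one of those keys as a Bearer token. A minimal sketch (the base URL, router prefix, key, and model name are assumptions about a local deployment, not part of this commit):

    import requests

    # Hypothetical deployment values; adjust the prefix and port to wherever
    # this router is actually mounted.
    resp = requests.post(
        "http://localhost:5670/api/v2/chat/completions",
        headers={"Authorization": "Bearer dbgpt-key-1"},
        json={"model": "chatgpt_proxyllm", "messages": "Hello", "stream": False},
    )
    # With API keys configured, a wrong or missing key yields HTTP 401 with
    # code "invalid_api_key".
    print(resp.status_code, resp.text)
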
@router.post("/v2/chat/completions", dependencies=[Depends(check_api_key)])
async def chat_completions(
    request: ChatCompletionRequestBody = Body(),
    flow_service: FlowService = Depends(get_chat_flow),
):
    """Chat V2 completions.

    Args:
        request (ChatCompletionRequestBody): The chat request.
        flow_service (FlowService): The flow service.
    Raises:
        HTTPException: If the request is invalid.
    """
    logger.info(
        f"chat_completions:{request.chat_mode},{request.chat_param},{request.model}"
    )
    headers = {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Transfer-Encoding": "chunked",
    }
    # Validate the chat request before dispatching on chat_mode.
    check_chat_request(request)
    if request.conv_uid is None:
        request.conv_uid = str(uuid.uuid4())
    if request.chat_mode == "chat_app":
        if request.stream is False:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "chat_app does not support non-streaming requests yet",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_request_error",
                    }
                },
            )
        return StreamingResponse(
            chat_app_stream_wrapper(
                request=request,
            ),
            headers=headers,
            media_type="text/event-stream",
        )
    elif request.chat_mode == ChatScene.ChatFlow.value():
        # flow_ctx = CommonLLMHTTPRequestContext(
        #     conv_uid=request.conv_uid,
        #     chat_mode=request.chat_mode,
        #     user_name=request.user_name,
        #     sys_code=request.sys_code,
        # )
        # flow_req = CommonLLMHttpRequestBody(
        #     model=request.model,
        #     messages=request.chat_param,
        #     stream=True,
        #     context=flow_ctx,
        # )
        return StreamingResponse(
            chat_flow_stream_wrapper(request),
            headers=headers,
            media_type="text/event-stream",
        )
    elif (
        request.chat_mode is None
        or request.chat_mode == ChatScene.ChatNormal.value()
        or request.chat_mode == ChatScene.ChatKnowledge.value()
    ):
        with root_tracer.start_span(
            "get_chat_instance", span_type=SpanType.CHAT, metadata=request.dict()
        ):
            chat: BaseChat = await get_chat_instance(request)

        if not request.stream:
            return await no_stream_wrapper(request, chat)
        else:
            return StreamingResponse(
                stream_generator(chat, request.incremental, request.model),
                headers=headers,
                media_type="text/plain",
            )
    else:
        raise HTTPException(
            status_code=400,
            detail={
                "error": {
                    "message": "chat mode now only supports chat_normal, "
                    "chat_app, chat_flow and chat_knowledge",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_chat_mode",
                }
            },
        )

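For reference, the dispatch above keys off `chat_mode`, with `chat_param` carrying the mode-specific selector (app name, flow id, or knowledge space). A sketch of a knowledge-mode request body (field names are those used from `ChatCompletionRequestBody` above; the space and model names are hypothetical):

    chat_knowledge_request = {
        "chat_mode": "chat_knowledge",
        "chat_param": "my_space",      # hypothetical knowledge space name
        "model": "chatgpt_proxyllm",   # hypothetical model name
        "messages": "What is DB-GPT?",
        "stream": True,                # False returns one ChatCompletionResponse
    }
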
async def get_chat_instance(dialogue: ChatCompletionRequestBody = Body()) -> BaseChat:
    """
    Get chat instance
    Args:
        dialogue (ChatCompletionRequestBody): The chat request.
    """
    logger.info(f"get_chat_instance:{dialogue}")
    if not dialogue.chat_mode:
        dialogue.chat_mode = ChatScene.ChatNormal.value()
    if not dialogue.conv_uid:
        conv_vo = __new_conversation(
            dialogue.chat_mode, dialogue.user_name, dialogue.sys_code
        )
        dialogue.conv_uid = conv_vo.conv_uid

    if not ChatScene.is_valid_mode(dialogue.chat_mode):
        raise StopAsyncIteration(f"Unsupported Chat Mode: {dialogue.chat_mode}!")

    chat_param = {
        "chat_session_id": dialogue.conv_uid,
        "user_name": dialogue.user_name,
        "sys_code": dialogue.sys_code,
        "current_user_input": dialogue.messages,
        "select_param": dialogue.chat_param,
        "model_name": dialogue.model,
    }
    chat: BaseChat = await blocking_func_to_async(
        get_executor(),
        CHAT_FACTORY.get_implementation,
        dialogue.chat_mode,
        **{"chat_param": chat_param},
    )
    return chat

async def no_stream_wrapper(
    request: ChatCompletionRequestBody, chat: BaseChat
) -> ChatCompletionResponse:
    """
    No-stream wrapper
    Args:
        request (ChatCompletionRequestBody): request
        chat (BaseChat): chat
    """
    with root_tracer.start_span("no_stream_generator"):
        response = await chat.nostream_call()
        msg = response.replace("\ufffd", "")
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content=msg),
        )
        usage = UsageInfo()
        return ChatCompletionResponse(
            id=request.conv_uid, choices=[choice_data], model=request.model, usage=usage
        )

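On the wire, the non-stream branch above serializes to an OpenAI-style body via fastchat's `ChatCompletionResponse`. Roughly (values purely illustrative; optional fields such as finish_reason omitted):

    example_response = {
        "id": "4a8f...-conv-uid",        # request.conv_uid
        "object": "chat.completion",
        "created": 1710000000,
        "model": "chatgpt_proxyllm",
        "choices": [
            {"index": 0, "message": {"role": "assistant", "content": "..."}}
        ],
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    }
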
async def chat_app_stream_wrapper(request: ChatCompletionRequestBody = None):
    """Chat app stream

    Args:
        request (ChatCompletionRequestBody): request
    """
    async for output in multi_agents.app_agent_chat(
        conv_uid=request.conv_uid,
        gpts_name=request.chat_param,
        user_query=request.messages,
        user_code=request.user_name,
        sys_code=request.sys_code,
    ):
        match = re.search(r"data:\s*({.*})", output)
        if match:
            json_str = match.group(1)
            vis = json.loads(json_str)
            vis_content = vis.get("vis", None)
            if vis_content != "[DONE]":
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(role="assistant", content=vis_content),
                )
                chunk = ChatCompletionStreamResponse(
                    id=request.conv_uid,
                    choices=[choice_data],
                    model=request.model,
                    created=int(time.time()),
                )
                content = (
                    f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
                )
                yield content
    yield "data: [DONE]\n\n"

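Each chunk yielded above is a `data: {...}` SSE line whose payload follows the OpenAI stream-chunk shape, terminated by `data: [DONE]`. A minimal client-side parser for that framing (a sketch, not part of this commit):

    import json

    def parse_sse_line(line: str):
        # Return the assistant delta text from one SSE line, or None for
        # anything else (non-data lines or the [DONE] sentinel).
        if not line.startswith("data: "):
            return None
        payload = line[len("data: "):].strip()
        if payload == "[DONE]":
            return None
        chunk = json.loads(payload)
        return chunk["choices"][0]["delta"].get("content")
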
async def chat_flow_stream_wrapper(
    request: ChatCompletionRequestBody = None,
):
    """Chat flow stream

    Args:
        request (ChatCompletionRequestBody): request
    """
    flow_service = get_chat_flow()
    flow_ctx = CommonLLMHTTPRequestContext(
        conv_uid=request.conv_uid,
        chat_mode=request.chat_mode,
        user_name=request.user_name,
        sys_code=request.sys_code,
    )
    flow_req = CommonLLMHttpRequestBody(
        model=request.model,
        messages=request.chat_param,
        stream=True,
        context=flow_ctx,
    )
    async for output in flow_service.chat_flow(request.chat_param, flow_req):
        if output.startswith("data: [DONE]"):
            yield output
        elif output.startswith("data:"):
            output = output[len("data: ") :]
            choice_data = ChatCompletionResponseStreamChoice(
                index=0,
                delta=DeltaMessage(role="assistant", content=output),
            )
            chunk = ChatCompletionStreamResponse(
                id=request.conv_uid,
                choices=[choice_data],
                model=request.model,
                created=int(time.time()),
            )
            chat_completion_response = (
                f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
            )
            yield chat_completion_response

def check_chat_request(request: ChatCompletionRequestBody = Body()):
    """
    Check the chat request
    Args:
        request (ChatCompletionRequestBody): The chat request.
    Raises:
        HTTPException: If the request is invalid.
    """
    if request.chat_mode and request.chat_mode != ChatScene.ChatNormal.value():
        if request.chat_param is None:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "chat param is None",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_chat_param",
                    }
                },
            )
    if request.model is None:
        raise HTTPException(
            status_code=400,
            detail={
                "error": {
                    "message": "model is None",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_model",
                }
            },
        )
    if request.messages is None:
        raise HTTPException(
            status_code=400,
            detail={
                "error": {
                    "message": "messages is None",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_messages",
                }
            },
        )
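Putting it together, an end-to-end streaming call against this endpoint might look like the following (a sketch; the base URL, router prefix, key, and model name are assumptions about a local deployment):

    import asyncio

    import httpx

    async def demo():
        async with httpx.AsyncClient(base_url="http://localhost:5670/api") as client:
            async with client.stream(
                "POST",
                "/v2/chat/completions",
                headers={"Authorization": "Bearer dbgpt-key-1"},
                json={
                    "chat_mode": "chat_normal",
                    "model": "chatgpt_proxyllm",
                    "messages": "Say hello",
                    "stream": True,
                },
            ) as resp:
                async for line in resp.aiter_lines():
                    if line:
                        print(line)  # "data: {...}" chunks, then "data: [DONE]"

    asyncio.run(demo())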