feat: Add dbgpt client and add api v2

Fangyin Cheng
2024-03-18 18:24:08 +08:00
parent 4970c9f813
commit 0ed30aa44a
39 changed files with 2663 additions and 143 deletions

dbgpt/app/openapi/api_v2.py Normal file

@@ -0,0 +1,345 @@
import json
import re
import time
import uuid
from typing import Optional

from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastchat.protocol.api_protocol import (
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    DeltaMessage,
    UsageInfo,
)
from starlette.responses import StreamingResponse

from dbgpt.app.openapi.api_v1.api_v1 import (
    CHAT_FACTORY,
    __new_conversation,
    get_chat_flow,
    get_executor,
    stream_generator,
)
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.client.schemas import ChatCompletionRequestBody
from dbgpt.component import logger
from dbgpt.core.awel import CommonLLMHttpRequestBody, CommonLLMHTTPRequestContext
from dbgpt.model.cluster.apiserver.api import APISettings
from dbgpt.serve.agent.agents.controller import multi_agents
from dbgpt.serve.flow.api.endpoints import get_service
from dbgpt.serve.flow.service.service import Service as FlowService
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.tracer import SpanType, root_tracer

router = APIRouter()
api_settings = APISettings()
get_bearer_token = HTTPBearer(auto_error=False)

async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
    service=Depends(get_service),
) -> Optional[str]:
    """Check the API key.

    Args:
        auth (Optional[HTTPAuthorizationCredentials]): The bearer token.
        service (Service): The flow service.

    Returns:
        Optional[str]: The validated token, or None when no API keys are configured.
    """
    if service.config.api_keys:
        api_keys = [key.strip() for key in service.config.api_keys.split(",")]
        if auth is None or (token := auth.credentials) not in api_keys:
            raise HTTPException(
                status_code=401,
                detail={
                    "error": {
                        "message": "",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_api_key",
                    }
                },
            )
        return token
    else:
        # No API keys configured: allow all requests through unauthenticated.
        return None

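# A minimal sketch of how the guard above behaves from a client's point of view.
# The host, port, router prefix ("/api") and the configured key are assumptions,
# not part of this file: a request whose bearer token is not in
# `service.config.api_keys` gets a 401 carrying code "invalid_api_key".
#
#     import httpx
#
#     resp = httpx.post(
#         "http://127.0.0.1:5670/api/v2/chat/completions",
#         headers={"Authorization": "Bearer wrong-key"},
#         json={"model": "chatgpt_proxyllm", "messages": "Hello"},
#     )
#     assert resp.status_code == 401
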
@router.post("/v2/chat/completions", dependencies=[Depends(check_api_key)])
async def chat_completions(
    request: ChatCompletionRequestBody = Body(),
    flow_service: FlowService = Depends(get_chat_flow),
):
    """Chat V2 completions.

    Args:
        request (ChatCompletionRequestBody): The chat request.
        flow_service (FlowService): The flow service.

    Raises:
        HTTPException: If the request is invalid.
    """
    logger.info(
        f"chat_completions:{request.chat_mode},{request.chat_param},{request.model}"
    )
    headers = {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Transfer-Encoding": "chunked",
    }
    # Validate the request before dispatching it to a chat scene.
    check_chat_request(request)
    if request.conv_uid is None:
        request.conv_uid = str(uuid.uuid4())
    if request.chat_mode == "chat_app":
        if request.stream is False:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "chat_app does not support non-streaming requests yet",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_request_error",
                    }
                },
            )
        return StreamingResponse(
            chat_app_stream_wrapper(
                request=request,
            ),
            headers=headers,
            media_type="text/event-stream",
        )
    elif request.chat_mode == ChatScene.ChatFlow.value():
        return StreamingResponse(
            chat_flow_stream_wrapper(request),
            headers=headers,
            media_type="text/event-stream",
        )
    elif (
        request.chat_mode is None
        or request.chat_mode == ChatScene.ChatNormal.value()
        or request.chat_mode == ChatScene.ChatKnowledge.value()
    ):
        with root_tracer.start_span(
            "get_chat_instance", span_type=SpanType.CHAT, metadata=request.dict()
        ):
            chat: BaseChat = await get_chat_instance(request)
        if not request.stream:
            return await no_stream_wrapper(request, chat)
        else:
            return StreamingResponse(
                stream_generator(chat, request.incremental, request.model),
                headers=headers,
                media_type="text/plain",
            )
    else:
        raise HTTPException(
            status_code=400,
            detail={
                "error": {
                    "message": "chat mode only supports chat_normal, chat_app, chat_flow and chat_knowledge",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_chat_mode",
                }
            },
        )

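# A sketch of consuming the streaming modes above. The endpoint emits
# OpenAI-style SSE frames; the URL, prefix, key and model name are assumptions:
#
#     import httpx
#
#     with httpx.stream(
#         "POST",
#         "http://127.0.0.1:5670/api/v2/chat/completions",
#         headers={"Authorization": "Bearer dbgpt"},
#         json={
#             "model": "chatgpt_proxyllm",
#             "chat_mode": "chat_app",
#             "chat_param": "my_app_code",
#             "messages": "Hello",
#             "stream": True,
#         },
#     ) as resp:
#         for line in resp.iter_lines():
#             if line.startswith("data: ") and line != "data: [DONE]":
#                 print(line[len("data: "):])
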
async def get_chat_instance(dialogue: ChatCompletionRequestBody = Body()) -> BaseChat:
    """Get a chat instance for the given dialogue.

    Args:
        dialogue (ChatCompletionRequestBody): The chat request.
    """
    logger.info(f"get_chat_instance:{dialogue}")
    if not dialogue.chat_mode:
        dialogue.chat_mode = ChatScene.ChatNormal.value()
    if not dialogue.conv_uid:
        conv_vo = __new_conversation(
            dialogue.chat_mode, dialogue.user_name, dialogue.sys_code
        )
        dialogue.conv_uid = conv_vo.conv_uid
    if not ChatScene.is_valid_mode(dialogue.chat_mode):
        raise StopAsyncIteration(f"Unsupported Chat Mode: {dialogue.chat_mode}!")
    chat_param = {
        "chat_session_id": dialogue.conv_uid,
        "user_name": dialogue.user_name,
        "sys_code": dialogue.sys_code,
        "current_user_input": dialogue.messages,
        "select_param": dialogue.chat_param,
        "model_name": dialogue.model,
    }
    # Chat construction may block, so run it in the executor thread pool.
    chat: BaseChat = await blocking_func_to_async(
        get_executor(),
        CHAT_FACTORY.get_implementation,
        dialogue.chat_mode,
        chat_param=chat_param,
    )
    return chat

async def no_stream_wrapper(
    request: ChatCompletionRequestBody, chat: BaseChat
) -> ChatCompletionResponse:
    """Wrap a non-streaming chat call in an OpenAI-style completion response.

    Args:
        request (ChatCompletionRequestBody): The chat request.
        chat (BaseChat): The chat instance.
    """
    with root_tracer.start_span("no_stream_generator"):
        response = await chat.nostream_call()
        # Strip replacement characters left over from undecodable bytes.
        msg = response.replace("\ufffd", "")
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content=msg),
        )
        usage = UsageInfo()
        return ChatCompletionResponse(
            id=request.conv_uid, choices=[choice_data], model=request.model, usage=usage
        )

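# The non-streaming path above returns a body shaped roughly like this
# (ids, model name and content are illustrative, not produced by this file):
#
#     {
#         "id": "c9f0a1b2-...",
#         "object": "chat.completion",
#         "model": "chatgpt_proxyllm",
#         "choices": [
#             {"index": 0, "message": {"role": "assistant", "content": "Hi!"}}
#         ],
#         "usage": {"prompt_tokens": 0, "total_tokens": 0, "completion_tokens": 0}
#     }
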
async def chat_app_stream_wrapper(request: ChatCompletionRequestBody):
    """Stream an app (multi-agent) chat as OpenAI-style SSE chunks.

    Args:
        request (ChatCompletionRequestBody): The chat request.
    """
    async for output in multi_agents.app_agent_chat(
        conv_uid=request.conv_uid,
        gpts_name=request.chat_param,
        user_query=request.messages,
        user_code=request.user_name,
        sys_code=request.sys_code,
    ):
        # The agent output is itself an SSE frame; extract its JSON payload.
        match = re.search(r"data:\s*({.*})", output)
        if match:
            json_str = match.group(1)
            vis = json.loads(json_str)
            vis_content = vis.get("vis", None)
            if vis_content != "[DONE]":
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(role="assistant", content=vis_content),
                )
                chunk = ChatCompletionStreamResponse(
                    id=request.conv_uid,
                    choices=[choice_data],
                    model=request.model,
                    created=int(time.time()),
                )
                content = (
                    f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
                )
                yield content
    yield "data: [DONE]\n\n"

async def chat_flow_stream_wrapper(
    request: ChatCompletionRequestBody,
):
    """Stream a flow chat as OpenAI-style SSE chunks.

    Args:
        request (ChatCompletionRequestBody): The chat request.
    """
    flow_service = get_chat_flow()
    flow_ctx = CommonLLMHTTPRequestContext(
        conv_uid=request.conv_uid,
        chat_mode=request.chat_mode,
        user_name=request.user_name,
        sys_code=request.sys_code,
    )
    flow_req = CommonLLMHttpRequestBody(
        model=request.model,
        messages=request.messages,
        stream=True,
        context=flow_ctx,
    )
    async for output in flow_service.chat_flow(request.chat_param, flow_req):
        if output.startswith("data: [DONE]"):
            # Pass the sentinel through untouched; don't re-wrap it as content.
            yield output
        elif output.startswith("data:"):
            output = output[len("data: ") :]
            choice_data = ChatCompletionResponseStreamChoice(
                index=0,
                delta=DeltaMessage(role="assistant", content=output),
            )
            chunk = ChatCompletionStreamResponse(
                id=request.conv_uid,
                choices=[choice_data],
                model=request.model,
                created=int(time.time()),
            )
            chat_completion_response = (
                f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
            )
            yield chat_completion_response

def check_chat_request(request: ChatCompletionRequestBody = Body()):
    """Check the chat request.

    Args:
        request (ChatCompletionRequestBody): The chat request.

    Raises:
        HTTPException: If the request is invalid.
    """
    if request.chat_mode and request.chat_mode != ChatScene.ChatNormal.value():
        # Every mode except chat_normal needs a chat_param (app code, flow id, ...).
        if request.chat_param is None:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "chat_param is None",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_chat_param",
                    }
                },
            )
    if request.model is None:
        raise HTTPException(
            status_code=400,
            detail={
                "error": {
                    "message": "model is None",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_model",
                }
            },
        )
    if request.messages is None:
        raise HTTPException(
            status_code=400,
            detail={
                "error": {
                    "message": "messages is None",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_messages",
                }
            },
        )
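
# A rejected request surfaces as a FastAPI 400 response whose body nests the
# OpenAI-style error envelope under "detail", e.g. (illustrative):
#
#     {
#         "detail": {
#             "error": {
#                 "message": "model is None",
#                 "type": "invalid_request_error",
#                 "param": null,
#                 "code": "invalid_model"
#             }
#         }
#     }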