mirror of
https://github.com/csunny/DB-GPT.git
synced 2026-01-14 20:28:12 +00:00
style:format code style
format code style
This commit is contained in:
@@ -2,17 +2,29 @@ import uuid
|
||||
import json
|
||||
import asyncio
|
||||
import time
|
||||
from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Request,
|
||||
Body,
|
||||
status,
|
||||
HTTPException,
|
||||
Response,
|
||||
BackgroundTasks,
|
||||
)
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from typing import List
|
||||
|
||||
from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
|
||||
from pilot.openapi.api_v1.api_view_model import (
|
||||
Result,
|
||||
ConversationVo,
|
||||
MessageVo,
|
||||
ChatSceneVo,
|
||||
)
|
||||
from pilot.configs.config import Config
|
||||
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
|
||||
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
|
||||
@@ -103,7 +115,7 @@ async def dialogue_scenes():
|
||||
|
||||
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
|
||||
async def dialogue_new(
|
||||
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
|
||||
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
|
||||
):
|
||||
unique_id = uuid.uuid1()
|
||||
return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
|
||||
@@ -220,11 +232,19 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
||||
}
|
||||
|
||||
if not chat.prompt_template.stream_out:
|
||||
return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream',
|
||||
background=background_tasks)
|
||||
return StreamingResponse(
|
||||
no_stream_generator(chat),
|
||||
headers=headers,
|
||||
media_type="text/event-stream",
|
||||
background=background_tasks,
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain',
|
||||
background=background_tasks)
|
||||
return StreamingResponse(
|
||||
stream_generator(chat),
|
||||
headers=headers,
|
||||
media_type="text/plain",
|
||||
background=background_tasks,
|
||||
)
|
||||
|
||||
|
||||
def release_model_semaphore():
|
||||
@@ -236,12 +256,15 @@ async def no_stream_generator(chat):
|
||||
msg = msg.replace("\n", "\\n")
|
||||
yield f"data: {msg}\n\n"
|
||||
|
||||
|
||||
async def stream_generator(chat):
|
||||
model_response = chat.stream_call()
|
||||
if not CFG.NEW_SERVER_MODE:
|
||||
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||
if chunk:
|
||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
|
||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
|
||||
chunk, chat.skip_echo_len
|
||||
)
|
||||
chat.current_message.add_ai_message(msg)
|
||||
msg = msg.replace("\n", "\\n")
|
||||
yield f"data:{msg}\n\n"
|
||||
@@ -249,7 +272,9 @@ async def stream_generator(chat):
|
||||
else:
|
||||
for chunk in model_response:
|
||||
if chunk:
|
||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
|
||||
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
|
||||
chunk, chat.skip_echo_len
|
||||
)
|
||||
chat.current_message.add_ai_message(msg)
|
||||
|
||||
msg = msg.replace("\n", "\\n")
|
||||
@@ -259,4 +284,6 @@ async def stream_generator(chat):
|
||||
|
||||
|
||||
def message2Vo(message: dict, order) -> MessageVo:
|
||||
return MessageVo(role=message['type'], context=message['data']['content'], order=order)
|
||||
return MessageVo(
|
||||
role=message["type"], context=message["data"]["content"], order=order
|
||||
)
|
||||
|
||||
@@ -76,18 +76,34 @@ def document_list(space_name: str, query_request: DocumentQueryRequest):
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/document/upload")
|
||||
async def document_upload(space_name: str, doc_name: str = Form(...), doc_type: str = Form(...), doc_file: UploadFile = File(...)):
|
||||
async def document_upload(
|
||||
space_name: str,
|
||||
doc_name: str = Form(...),
|
||||
doc_type: str = Form(...),
|
||||
doc_file: UploadFile = File(...),
|
||||
):
|
||||
print(f"/document/upload params: {space_name}")
|
||||
try:
|
||||
if doc_file:
|
||||
with NamedTemporaryFile(dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False) as tmp:
|
||||
with NamedTemporaryFile(
|
||||
dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False
|
||||
) as tmp:
|
||||
tmp.write(await doc_file.read())
|
||||
tmp_path = tmp.name
|
||||
shutil.move(tmp_path, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename))
|
||||
shutil.move(
|
||||
tmp_path,
|
||||
os.path.join(
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
|
||||
),
|
||||
)
|
||||
request = KnowledgeDocumentRequest()
|
||||
request.doc_name = doc_name
|
||||
request.doc_type = doc_type
|
||||
request.content = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename),
|
||||
request.content = (
|
||||
os.path.join(
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
|
||||
),
|
||||
)
|
||||
knowledge_space_service.create_knowledge_document(
|
||||
space=space_name, request=request
|
||||
)
|
||||
|
||||
@@ -25,7 +25,10 @@ from pilot.openapi.knowledge.request.knowledge_request import (
|
||||
)
|
||||
from enum import Enum
|
||||
|
||||
from pilot.openapi.knowledge.request.knowledge_response import ChunkQueryResponse, DocumentQueryResponse
|
||||
from pilot.openapi.knowledge.request.knowledge_response import (
|
||||
ChunkQueryResponse,
|
||||
DocumentQueryResponse,
|
||||
)
|
||||
|
||||
knowledge_space_dao = KnowledgeSpaceDao()
|
||||
knowledge_document_dao = KnowledgeDocumentDao()
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import BaseModel
|
||||
|
||||
class ChunkQueryResponse(BaseModel):
|
||||
"""data: data"""
|
||||
|
||||
data: List = None
|
||||
"""total: total size"""
|
||||
total: int = None
|
||||
@@ -14,9 +15,9 @@ class ChunkQueryResponse(BaseModel):
|
||||
|
||||
class DocumentQueryResponse(BaseModel):
|
||||
"""data: data"""
|
||||
|
||||
data: List = None
|
||||
"""total: total size"""
|
||||
total: int = None
|
||||
"""page: current page"""
|
||||
page: int = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user