feat: add Client and API v2 (#1316)
# Description

1. Provide `/api/v2` for DB-GPT.
2. Add the DB-GPT Python Client for Chat, Flow, App, and Knowledge, including:
   - Chat
   - Create
   - Update
   - Delete
   - Get
   - List
3. Add examples in `examples/client/`.
4. Add an API Reference document.

# How Has This Been Tested?

## Test Chat Normal

### Curl

1. Set `API_KEYS=dbgpt` in `.env`.
2. Start the server: `python dbgpt/app/dbgpt_server.py`.
3. Test with curl:

```
DBGPT_API_KEY=dbgpt

curl -X POST "http://localhost:5000/api/v2/chat/completions" \
    -H "Authorization: Bearer $DBGPT_API_KEY" \
    -H "accept: application/json" \
    -H "Content-Type: application/json" \
    -d "{\"messages\":\"Hello\",\"model\":\"chatgpt_proxyllm\"}"
```

```
data: {"id": "chatcmpl-ab5fd180-e699-11ee-8388-acde48001122", "model": "chatgpt_proxyllm", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hello"}}]}

data: {"id": "chatcmpl-ab5fd180-e699-11ee-8388-acde48001122", "model": "chatgpt_proxyllm", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "!"}}]}

data: {"id": "chatcmpl-ab5fd180-e699-11ee-8388-acde48001122", "model": "chatgpt_proxyllm", "choices": [{"index": 0, "delta": {"role": "assistant", "content": " How"}}]}

data: {"id": "chatcmpl-ab5fd180-e699-11ee-8388-acde48001122", "model": "chatgpt_proxyllm", "choices": [{"index": 0, "delta": {"role": "assistant", "content": " can"}}]}

data: {"id": "chatcmpl-ab5fd180-e699-11ee-8388-acde48001122", "model": "chatgpt_proxyllm", "choices": [{"index": 0, "delta": {"role": "assistant", "content": " I"}}]}

data: {"id": "chatcmpl-ab5fd180-e699-11ee-8388-acde48001122", "model": "chatgpt_proxyllm", "choices": [{"index": 0, "delta": {"role": "assistant", "content": " assist"}}]}

data: {"id": "chatcmpl-ab5fd180-e699-11ee-8388-acde48001122", "model": "chatgpt_proxyllm", "choices": [{"index": 0, "delta": {"role": "assistant", "content": " you"}}]}

data: {"id": "chatcmpl-ab5fd180-e699-11ee-8388-acde48001122", "model": "chatgpt_proxyllm", "choices": [{"index": 0, "delta": {"role": "assistant", "content": " today"}}]}

data: {"id": "chatcmpl-ab5fd180-e699-11ee-8388-acde48001122", "model": "chatgpt_proxyllm", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "?"}}]}

data: [DONE]
```

### Python

```python
from dbgpt.client import Client

DBGPT_API_KEY = "dbgpt"
client = Client(api_key=DBGPT_API_KEY)

# stream
async for data in client.chat_stream(
    model="chatgpt_proxyllm",
    messages="hello",
):
    print(data)

# no stream
await client.chat(model="chatgpt_proxyllm", messages="hello")
```

## Test Chat App

### Curl

Test with curl:

```
DBGPT_API_KEY=dbgpt
APP_CODE={YOUR_APP_CODE}

curl -X POST "http://localhost:5000/api/v2/chat/completions" \
    -H "Authorization: Bearer $DBGPT_API_KEY" \
    -H "accept: application/json" \
    -H "Content-Type: application/json" \
    -d "{\"messages\":\"Hello\",\"model\":\"chatgpt_proxyllm\", \"chat_mode\": \"chat_app\", \"chat_param\": \"$APP_CODE\"}"
```

### Python

```python
from dbgpt.client import Client

DBGPT_API_KEY = "dbgpt"
client = Client(api_key=DBGPT_API_KEY)
APP_CODE = "{YOUR_APP_CODE}"

async for data in client.chat_stream(
    model="chatgpt_proxyllm",
    messages="hello",
    chat_mode="chat_app",
    chat_param=APP_CODE,
):
    print(data)
```
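## Test Chat Flow

### Python

A parallel sketch for AWEL flows, not part of the original test run: per `chat_flow_stream_wrapper` in the new `api_v2.py` below, `chat_param` carries the flow identifier, and the value here is a placeholder.

```python
from dbgpt.client import Client

DBGPT_API_KEY = "dbgpt"
client = Client(api_key=DBGPT_API_KEY)
FLOW_ID = "{YOUR_FLOW_ID}"  # placeholder flow identifier

async for data in client.chat_stream(
    model="chatgpt_proxyllm",
    messages="hello",
    chat_mode="chat_flow",
    chat_param=FLOW_ID,
):
    print(data)
```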
# Snapshots:

Include snapshots for easier review.

# Checklist:

- [x] My code follows the style guidelines of this project
- [x] I have already rebased the commits and made the commit message conform to the project standard
- [x] I have performed a self-review of my own code
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation
- [x] Any dependent changes have been merged and published in downstream modules
```diff
@@ -89,13 +89,17 @@ def mount_routers(app: FastAPI):
         router as api_editor_route_v1,
     )
     from dbgpt.app.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
+    from dbgpt.app.openapi.api_v2 import router as api_v2
     from dbgpt.serve.agent.app.controller import router as gpts_v1
+    from dbgpt.serve.agent.app.endpoints import router as app_v2
 
     app.include_router(api_v1, prefix="/api", tags=["Chat"])
+    app.include_router(api_v2, prefix="/api", tags=["ChatV2"])
     app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
     app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"])
     app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])
     app.include_router(gpts_v1, prefix="/api", tags=["GptsApp"])
+    app.include_router(app_v2, prefix="/api", tags=["App"])
 
     app.include_router(knowledge_router, tags=["Knowledge"])
 
```
```diff
@@ -2,12 +2,12 @@
 """
 from dbgpt.app.knowledge.chunk_db import DocumentChunkEntity
 from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity
-from dbgpt.app.knowledge.space_db import KnowledgeSpaceEntity
 from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity
 from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity
 from dbgpt.serve.agent.db.my_plugin_db import MyPluginEntity
 from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity
 from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
+from dbgpt.serve.rag.models.models import KnowledgeSpaceEntity
 from dbgpt.storage.chat_history.chat_history_db import (
     ChatHistoryEntity,
     ChatHistoryMessageEntity,
```
```diff
@@ -5,6 +5,8 @@ from dbgpt.component import SystemApp
 def register_serve_apps(system_app: SystemApp, cfg: Config):
     """Register serve apps"""
     system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE)
+    if cfg.API_KEYS:
+        system_app.config.set("dbgpt.app.global.api_keys", cfg.API_KEYS)
 
     # ################################ Prompt Serve Register Begin ######################################
     from dbgpt.serve.prompt.serve import (
```
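Since `check_api_key` in the new `api_v2.py` below splits the configured value on commas, several keys can be enabled at once; a hedged `.env` sketch:

```
# .env sketch: any of the listed keys is accepted as a bearer token
API_KEYS=dbgpt,another-secret-key
```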
```diff
@@ -42,4 +44,12 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):
 
     # Register serve app
     system_app.register(FlowServe)
+
+    from dbgpt.serve.rag.serve import (
+        SERVE_CONFIG_KEY_PREFIX as RAG_SERVE_CONFIG_KEY_PREFIX,
+    )
+    from dbgpt.serve.rag.serve import Serve as RagServe
+
+    # Register serve app
+    system_app.register(RagServe)
     # ################################ AWEL Flow Serve Register End ########################################
```
```diff
@@ -4,7 +4,7 @@ import shutil
 import tempfile
 from typing import List
 
-from fastapi import APIRouter, File, Form, UploadFile
+from fastapi import APIRouter, Depends, File, Form, UploadFile
 
 from dbgpt._private.config import Config
 from dbgpt.app.knowledge.request.request import (
@@ -16,7 +16,6 @@ from dbgpt.app.knowledge.request.request import (
     KnowledgeDocumentRequest,
     KnowledgeQueryRequest,
     KnowledgeSpaceRequest,
-    KnowledgeSyncRequest,
     SpaceArgumentRequest,
 )
 from dbgpt.app.knowledge.request.response import KnowledgeQueryResponse
@@ -31,6 +30,8 @@ from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
 from dbgpt.rag.knowledge.base import ChunkStrategy
 from dbgpt.rag.knowledge.factory import KnowledgeFactory
 from dbgpt.rag.retriever.embedding import EmbeddingRetriever
+from dbgpt.serve.rag.api.schemas import KnowledgeSyncRequest
+from dbgpt.serve.rag.service.service import Service
 from dbgpt.storage.vector_store.base import VectorStoreConfig
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
 from dbgpt.util.tracer import SpanType, root_tracer
@@ -44,6 +45,11 @@ router = APIRouter()
 knowledge_space_service = KnowledgeService()
 
 
+def get_rag_service() -> Service:
+    """Get Rag Service."""
+    return Service.get_instance(CFG.SYSTEM_APP)
+
+
 @router.post("/knowledge/space/add")
 def space_add(request: KnowledgeSpaceRequest):
     print(f"/space/add params: {request}")
@@ -226,12 +232,20 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
 
 
 @router.post("/knowledge/{space_name}/document/sync_batch")
-def batch_document_sync(space_name: str, request: List[KnowledgeSyncRequest]):
+def batch_document_sync(
+    space_name: str,
+    request: List[KnowledgeSyncRequest],
+    service: Service = Depends(get_rag_service),
+):
     logger.info(f"Received params: {space_name}, {request}")
     try:
-        doc_ids = knowledge_space_service.batch_document_sync(
-            space_name=space_name, sync_requests=request
-        )
+        space = service.get({"name": space_name})
+        for sync_request in request:
+            sync_request.space_id = space.id
+        doc_ids = service.sync_document(requests=request)
+        # doc_ids = service.sync_document(
+        #     space_name=space_name, sync_requests=request
+        # )
         return Result.succ({"tasks": doc_ids})
     except Exception as e:
         return Result.failed(code="E000X", msg=f"document sync error {e}")
```
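A hedged curl sketch of the reworked batch endpoint; the route prefix and the `ChunkParameters` payload fields are assumptions based on the surrounding code, not captured from a run:

```
curl -X POST "http://localhost:5000/knowledge/{YOUR_SPACE_NAME}/document/sync_batch" \
    -H "Content-Type: application/json" \
    -d '[{"doc_id": 1, "chunk_parameters": {"chunk_strategy": "CHUNK_BY_SIZE"}}]'
```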
```diff
@@ -1,9 +1,11 @@
 from datetime import datetime
-from typing import List
+from typing import Any, Dict, List, Union
 
 from sqlalchemy import Column, DateTime, Integer, String, Text, func
 
 from dbgpt._private.config import Config
+from dbgpt.serve.conversation.api.schemas import ServeRequest
+from dbgpt.serve.rag.api.schemas import DocumentServeRequest, DocumentServeResponse
 from dbgpt.storage.metadata import BaseDao, Model
 
 CFG = Config()
@@ -218,3 +220,70 @@ class KnowledgeDocumentDao(BaseDao):
         knowledge_documents.delete()
         session.commit()
         session.close()
+
+    def from_request(
+        self, request: Union[ServeRequest, Dict[str, Any]]
+    ) -> KnowledgeDocumentEntity:
+        """Convert the request to an entity
+
+        Args:
+            request (Union[ServeRequest, Dict[str, Any]]): The request
+
+        Returns:
+            T: The entity
+        """
+        request_dict = (
+            request.dict() if isinstance(request, DocumentServeRequest) else request
+        )
+        entity = KnowledgeDocumentEntity(**request_dict)
+        return entity
+
+    def to_request(self, entity: KnowledgeDocumentEntity) -> DocumentServeResponse:
+        """Convert the entity to a request
+
+        Args:
+            entity (T): The entity
+
+        Returns:
+            REQ: The request
+        """
+        return DocumentServeResponse(
+            id=entity.id,
+            doc_name=entity.doc_name,
+            doc_type=entity.doc_type,
+            space=entity.space,
+            chunk_size=entity.chunk_size,
+            status=entity.status,
+            last_sync=entity.last_sync,
+            content=entity.content,
+            result=entity.result,
+            vector_ids=entity.vector_ids,
+            summary=entity.summary,
+            gmt_created=entity.gmt_created,
+            gmt_modified=entity.gmt_modified,
+        )
+
+    def to_response(self, entity: KnowledgeDocumentEntity) -> DocumentServeResponse:
+        """Convert the entity to a response
+
+        Args:
+            entity (T): The entity
+
+        Returns:
+            REQ: The request
+        """
+        return DocumentServeResponse(
+            id=entity.id,
+            doc_name=entity.doc_name,
+            doc_type=entity.doc_type,
+            space=entity.space,
+            chunk_size=entity.chunk_size,
+            status=entity.status,
+            last_sync=entity.last_sync,
+            content=entity.content,
+            result=entity.result,
+            vector_ids=entity.vector_ids,
+            summary=entity.summary,
+            gmt_created=entity.gmt_created,
+            gmt_modified=entity.gmt_modified,
+        )
```
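A minimal usage sketch for the new converters (field values are illustrative):

```python
from dbgpt.app.knowledge.document_db import KnowledgeDocumentDao

dao = KnowledgeDocumentDao()

# from_request accepts either a DocumentServeRequest or a plain dict.
entity = dao.from_request(
    {"doc_name": "readme.md", "doc_type": "DOCUMENT", "space": "default"}
)

# to_response serializes an entity back into a DocumentServeResponse.
response = dao.to_response(entity)
```

Note that `to_request` and `to_response` currently return identical `DocumentServeResponse` objects.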
```diff
@@ -17,6 +17,8 @@ class KnowledgeQueryRequest(BaseModel):
 class KnowledgeSpaceRequest(BaseModel):
     """name: knowledge space name"""
 
-    """vector_type: vector type"""
+    id: int = None
+    name: str = None
     """vector_type: vector type"""
     vector_type: str = None
@@ -37,9 +39,6 @@ class KnowledgeDocumentRequest(BaseModel):
     """content: content"""
     source: str = None
 
-    """text_chunk_size: text_chunk_size"""
-    # text_chunk_size: int
-
 
 class DocumentQueryRequest(BaseModel):
     """doc_name: doc path"""
@@ -80,20 +79,6 @@ class DocumentSyncRequest(BaseModel):
     chunk_overlap: Optional[int] = None
 
 
-class KnowledgeSyncRequest(BaseModel):
-    """Sync request"""
-
-    """doc_ids: doc ids"""
-    doc_id: int
-
-    """model_name: model name"""
-    model_name: Optional[str] = None
-
-    """chunk_parameters: chunk parameters
-    """
-    chunk_parameters: ChunkParameters
-
-
 class ChunkQueryRequest(BaseModel):
     """id: id"""
```
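Callers now import `KnowledgeSyncRequest` from the serve layer instead; a hedged sketch, where the `ChunkParameters` import path and field name are assumptions:

```python
from dbgpt.rag.chunk_manager import ChunkParameters  # assumed path
from dbgpt.serve.rag.api.schemas import KnowledgeSyncRequest

sync_request = KnowledgeSyncRequest(
    doc_id=1,
    model_name="chatgpt_proxyllm",
    chunk_parameters=ChunkParameters(chunk_strategy="CHUNK_BY_SIZE"),
)
```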
```diff
@@ -1,7 +1,6 @@
 import json
 import logging
 from datetime import datetime
-from enum import Enum
 from typing import List
 
 from dbgpt._private.config import Config
@@ -17,7 +16,6 @@ from dbgpt.app.knowledge.request.request import (
     DocumentSyncRequest,
     KnowledgeDocumentRequest,
     KnowledgeSpaceRequest,
-    KnowledgeSyncRequest,
     SpaceArgumentRequest,
 )
 from dbgpt.app.knowledge.request.response import (
@@ -25,7 +23,6 @@ from dbgpt.app.knowledge.request.response import (
     DocumentQueryResponse,
     SpaceQueryResponse,
 )
-from dbgpt.app.knowledge.space_db import KnowledgeSpaceDao, KnowledgeSpaceEntity
 from dbgpt.component import ComponentType
 from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
 from dbgpt.core import Chunk
@@ -38,8 +35,11 @@ from dbgpt.rag.text_splitter.text_splitter import (
     RecursiveCharacterTextSplitter,
     SpacyTextSplitter,
 )
+from dbgpt.serve.rag.api.schemas import KnowledgeSyncRequest
 from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
 from dbgpt.serve.rag.assembler.summary import SummaryAssembler
+from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
+from dbgpt.serve.rag.service.service import Service, SyncStatus
 from dbgpt.storage.vector_store.base import VectorStoreConfig
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
 from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
@@ -53,13 +53,6 @@ logger = logging.getLogger(__name__)
 CFG = Config()
 
 
-class SyncStatus(Enum):
-    TODO = "TODO"
-    FAILED = "FAILED"
-    RUNNING = "RUNNING"
-    FINISHED = "FINISHED"
-
-
 # default summary max iteration call with llm.
 DEFAULT_SUMMARY_MAX_ITERATION = 5
 # default summary concurrency call with llm.
@@ -88,8 +81,8 @@ class KnowledgeService:
         spaces = knowledge_space_dao.get_knowledge_space(query)
         if len(spaces) > 0:
             raise Exception(f"space name:{request.name} have already named")
-        knowledge_space_dao.create_knowledge_space(request)
-        return True
+        space_id = knowledge_space_dao.create_knowledge_space(request)
+        return space_id
 
     def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
         """create knowledge document
@@ -199,7 +192,9 @@ class KnowledgeService:
         return res
 
     def batch_document_sync(
-        self, space_name, sync_requests: List[KnowledgeSyncRequest]
+        self,
+        space_name,
+        sync_requests: List[KnowledgeSyncRequest],
     ) -> List[int]:
         """batch sync knowledge document chunk into vector store
         Args:
```
```diff
@@ -1,93 +0,0 @@
-from datetime import datetime
-
-from sqlalchemy import Column, DateTime, Integer, String, Text
-
-from dbgpt._private.config import Config
-from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
-from dbgpt.storage.metadata import BaseDao, Model
-
-CFG = Config()
-
-
-class KnowledgeSpaceEntity(Model):
-    __tablename__ = "knowledge_space"
-    id = Column(Integer, primary_key=True)
-    name = Column(String(100))
-    vector_type = Column(String(100))
-    desc = Column(String(100))
-    owner = Column(String(100))
-    context = Column(Text)
-    gmt_created = Column(DateTime)
-    gmt_modified = Column(DateTime)
-
-    def __repr__(self):
-        return f"KnowledgeSpaceEntity(id={self.id}, name='{self.name}', vector_type='{self.vector_type}', desc='{self.desc}', owner='{self.owner}' context='{self.context}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
-
-
-class KnowledgeSpaceDao(BaseDao):
-    def create_knowledge_space(self, space: KnowledgeSpaceRequest):
-        session = self.get_raw_session()
-        knowledge_space = KnowledgeSpaceEntity(
-            name=space.name,
-            vector_type=CFG.VECTOR_STORE_TYPE,
-            desc=space.desc,
-            owner=space.owner,
-            gmt_created=datetime.now(),
-            gmt_modified=datetime.now(),
-        )
-        session.add(knowledge_space)
-        session.commit()
-        session.close()
-
-    def get_knowledge_space(self, query: KnowledgeSpaceEntity):
-        session = self.get_raw_session()
-        knowledge_spaces = session.query(KnowledgeSpaceEntity)
-        if query.id is not None:
-            knowledge_spaces = knowledge_spaces.filter(
-                KnowledgeSpaceEntity.id == query.id
-            )
-        if query.name is not None:
-            knowledge_spaces = knowledge_spaces.filter(
-                KnowledgeSpaceEntity.name == query.name
-            )
-        if query.vector_type is not None:
-            knowledge_spaces = knowledge_spaces.filter(
-                KnowledgeSpaceEntity.vector_type == query.vector_type
-            )
-        if query.desc is not None:
-            knowledge_spaces = knowledge_spaces.filter(
-                KnowledgeSpaceEntity.desc == query.desc
-            )
-        if query.owner is not None:
-            knowledge_spaces = knowledge_spaces.filter(
-                KnowledgeSpaceEntity.owner == query.owner
-            )
-        if query.gmt_created is not None:
-            knowledge_spaces = knowledge_spaces.filter(
-                KnowledgeSpaceEntity.gmt_created == query.gmt_created
-            )
-        if query.gmt_modified is not None:
-            knowledge_spaces = knowledge_spaces.filter(
-                KnowledgeSpaceEntity.gmt_modified == query.gmt_modified
-            )
-
-        knowledge_spaces = knowledge_spaces.order_by(
-            KnowledgeSpaceEntity.gmt_created.desc()
-        )
-        result = knowledge_spaces.all()
-        session.close()
-        return result
-
-    def update_knowledge_space(self, space: KnowledgeSpaceEntity):
-        session = self.get_raw_session()
-        session.merge(space)
-        session.commit()
-        session.close()
-        return True
-
-    def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
-        session = self.get_raw_session()
-        if space:
-            session.delete(space)
-            session.commit()
-        session.close()
```
```diff
@@ -13,11 +13,8 @@ from dbgpt._private.config import Config
 from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
 from dbgpt.app.knowledge.service import KnowledgeService
 from dbgpt.app.openapi.api_view_model import (
-    ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse,
     ChatSceneVo,
     ConversationVo,
-    DeltaMessage,
     MessageVo,
     Result,
 )
@@ -25,6 +22,11 @@ from dbgpt.app.scene import BaseChat, ChatFactory, ChatScene
 from dbgpt.component import ComponentType
 from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
 from dbgpt.core.awel import CommonLLMHttpRequestBody, CommonLLMHTTPRequestContext
+from dbgpt.core.schema.api import (
+    ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse,
+    DeltaMessage,
+)
 from dbgpt.datasource.db_conn_info import DBConfig, DbTypeInfo
 from dbgpt.model.base import FlatSupportedModel
 from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
@@ -439,7 +441,6 @@ async def stream_generator(chat, incremental: bool, model_name: str):
     span = root_tracer.start_span("stream_generator")
     msg = "[LLM_ERROR]: llm server has no output, maybe your prompt template is wrong."
 
-    stream_id = f"chatcmpl-{str(uuid.uuid1())}"
     previous_response = ""
     async for chunk in chat.stream_call():
         if chunk:
@@ -451,7 +452,7 @@ async def stream_generator(chat, incremental: bool, model_name: str):
                     delta=DeltaMessage(role="assistant", content=incremental_output),
                 )
                 chunk = ChatCompletionStreamResponse(
-                    id=stream_id, choices=[choice_data], model=model_name
+                    id=chat.chat_session_id, choices=[choice_data], model=model_name
                 )
                 yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
             else:
```
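With `id=chat.chat_session_id`, every chunk of a conversation's stream now carries the conversation id rather than a fresh `chatcmpl-` uuid per call, so clients can correlate chunks; an illustrative line (placeholder values, not captured output):

```
data: {"id": "<conv_uid>", "model": "chatgpt_proxyllm", "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hello"}}]}
```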
dbgpt/app/openapi/api_v2.py (new file, 333 lines):

```python
import json
import re
import time
import uuid
from typing import Optional

from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette.responses import StreamingResponse

from dbgpt.app.openapi.api_v1.api_v1 import (
    CHAT_FACTORY,
    __new_conversation,
    get_chat_flow,
    get_chat_instance,
    get_executor,
    stream_generator,
)
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.client.schema import ChatCompletionRequestBody, ChatMode
from dbgpt.component import logger
from dbgpt.core.awel import CommonLLMHttpRequestBody, CommonLLMHTTPRequestContext
from dbgpt.core.schema.api import (
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    DeltaMessage,
    UsageInfo,
)
from dbgpt.model.cluster.apiserver.api import APISettings
from dbgpt.serve.agent.agents.controller import multi_agents
from dbgpt.serve.flow.api.endpoints import get_service
from dbgpt.serve.flow.service.service import Service as FlowService
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.tracer import SpanType, root_tracer

router = APIRouter()
api_settings = APISettings()
get_bearer_token = HTTPBearer(auto_error=False)


async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
    service=Depends(get_service),
) -> Optional[str]:
    """Check the api key
    Args:
        auth (Optional[HTTPAuthorizationCredentials]): The bearer token.
        service (Service): The flow service.
    """
    if service.config.api_keys:
        api_keys = [key.strip() for key in service.config.api_keys.split(",")]
        if auth is None or (token := auth.credentials) not in api_keys:
            raise HTTPException(
                status_code=401,
                detail={
                    "error": {
                        "message": "",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_api_key",
                    }
                },
            )
        return token
    else:
        return None
```
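When `API_KEYS` is set, a request without a valid bearer token is rejected by this dependency; a hedged sketch of the failure case:

```
curl -X POST "http://localhost:5000/api/v2/chat/completions" \
    -H "Content-Type: application/json" \
    -d "{\"messages\":\"Hello\",\"model\":\"chatgpt_proxyllm\"}"

# Expected: HTTP 401 with "code": "invalid_api_key" in the error body.
```

The new file continues with the endpoint itself: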
```python
@router.post("/v2/chat/completions", dependencies=[Depends(check_api_key)])
async def chat_completions(
    request: ChatCompletionRequestBody = Body(),
):
    """Chat V2 completions
    Args:
        request (ChatCompletionRequestBody): The chat request.
    Raises:
        HTTPException: If the request is invalid.
    """
    logger.info(
        f"chat_completions:{request.chat_mode},{request.chat_param},{request.model}"
    )
    headers = {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Transfer-Encoding": "chunked",
    }
    # check chat request
    check_chat_request(request)
    if request.conv_uid is None:
        request.conv_uid = str(uuid.uuid4())
    if request.chat_mode == ChatMode.CHAT_APP.value:
        if request.stream is False:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "chat app does not support non-stream mode yet",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_request_error",
                    }
                },
            )
        return StreamingResponse(
            chat_app_stream_wrapper(
                request=request,
            ),
            headers=headers,
            media_type="text/event-stream",
        )
    elif request.chat_mode == ChatMode.CHAT_AWEL_FLOW.value:
        return StreamingResponse(
            chat_flow_stream_wrapper(request),
            headers=headers,
            media_type="text/event-stream",
        )
    elif (
        request.chat_mode is None
        or request.chat_mode == ChatMode.CHAT_NORMAL.value
        or request.chat_mode == ChatMode.CHAT_KNOWLEDGE.value
    ):
        with root_tracer.start_span(
            "get_chat_instance", span_type=SpanType.CHAT, metadata=request.dict()
        ):
            chat: BaseChat = await get_chat_instance(request)

        if not request.stream:
            return await no_stream_wrapper(request, chat)
        else:
            return StreamingResponse(
                stream_generator(chat, request.incremental, request.model),
                headers=headers,
                media_type="text/plain",
            )
    else:
        raise HTTPException(
            status_code=400,
            detail={
                "error": {
                    "message": "chat_mode now only supports chat_normal, chat_app, chat_flow, chat_knowledge",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_chat_mode",
                }
            },
        )
```
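The `chat_knowledge` branch goes through the same `get_chat_instance` path as normal chat; a hedged curl sketch, assuming `chat_param` carries the knowledge space name:

```
DBGPT_API_KEY=dbgpt
SPACE_NAME={YOUR_SPACE_NAME}

curl -X POST "http://localhost:5000/api/v2/chat/completions" \
    -H "Authorization: Bearer $DBGPT_API_KEY" \
    -H "Content-Type: application/json" \
    -d "{\"messages\":\"Hello\",\"model\":\"chatgpt_proxyllm\", \"chat_mode\": \"chat_knowledge\", \"chat_param\": \"$SPACE_NAME\"}"
```

The remaining helpers follow: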
```python
async def get_chat_instance(dialogue: ChatCompletionRequestBody = Body()) -> BaseChat:
    """
    Get chat instance
    Args:
        dialogue (ChatCompletionRequestBody): The chat request.
    """
    logger.info(f"get_chat_instance:{dialogue}")
    if not dialogue.chat_mode:
        dialogue.chat_mode = ChatScene.ChatNormal.value()
    if not dialogue.conv_uid:
        conv_vo = __new_conversation(
            dialogue.chat_mode, dialogue.user_name, dialogue.sys_code
        )
        dialogue.conv_uid = conv_vo.conv_uid

    if not ChatScene.is_valid_mode(dialogue.chat_mode):
        raise StopAsyncIteration(f"Unsupported Chat Mode,{dialogue.chat_mode}!")

    chat_param = {
        "chat_session_id": dialogue.conv_uid,
        "user_name": dialogue.user_name,
        "sys_code": dialogue.sys_code,
        "current_user_input": dialogue.messages,
        "select_param": dialogue.chat_param,
        "model_name": dialogue.model,
    }
    chat: BaseChat = await blocking_func_to_async(
        get_executor(),
        CHAT_FACTORY.get_implementation,
        dialogue.chat_mode,
        **{"chat_param": chat_param},
    )
    return chat


async def no_stream_wrapper(
    request: ChatCompletionRequestBody, chat: BaseChat
) -> ChatCompletionResponse:
    """
    no stream wrapper
    Args:
        request (ChatCompletionRequestBody): request
        chat (BaseChat): chat
    """
    with root_tracer.start_span("no_stream_generator"):
        response = await chat.nostream_call()
        msg = response.replace("\ufffd", "")
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content=msg),
        )
        usage = UsageInfo()
        return ChatCompletionResponse(
            id=request.conv_uid, choices=[choice_data], model=request.model, usage=usage
        )
```
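For reference, a non-stream call built this way should yield a single OpenAI-style body; an illustrative (not captured) example:

```
{
  "id": "<conv_uid>",
  "model": "chatgpt_proxyllm",
  "choices": [
    {
      "index": 0,
      "message": {"role": "assistant", "content": "Hello! How can I assist you today?"}
    }
  ],
  "usage": {"prompt_tokens": 0, "total_tokens": 0, "completion_tokens": 0}
}
```

The stream wrappers and request validation close out the file: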
```python
async def chat_app_stream_wrapper(request: ChatCompletionRequestBody = None):
    """chat app stream
    Args:
        request (ChatCompletionRequestBody): request
    """
    async for output in multi_agents.app_agent_chat(
        conv_uid=request.conv_uid,
        gpts_name=request.chat_param,
        user_query=request.messages,
        user_code=request.user_name,
        sys_code=request.sys_code,
    ):
        match = re.search(r"data:\s*({.*})", output)
        if match:
            json_str = match.group(1)
            vis = json.loads(json_str)
            vis_content = vis.get("vis", None)
            if vis_content != "[DONE]":
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(role="assistant", content=vis.get("vis", None)),
                )
                chunk = ChatCompletionStreamResponse(
                    id=request.conv_uid,
                    choices=[choice_data],
                    model=request.model,
                    created=int(time.time()),
                )
                content = (
                    f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
                )
                yield content
    yield "data: [DONE]\n\n"


async def chat_flow_stream_wrapper(
    request: ChatCompletionRequestBody = None,
):
    """chat flow stream
    Args:
        request (ChatCompletionRequestBody): request
    """
    flow_service = get_chat_flow()
    flow_ctx = CommonLLMHTTPRequestContext(
        conv_uid=request.conv_uid,
        chat_mode=request.chat_mode,
        user_name=request.user_name,
        sys_code=request.sys_code,
    )
    flow_req = CommonLLMHttpRequestBody(
        model=request.model,
        messages=request.chat_param,
        stream=True,
        context=flow_ctx,
    )
    async for output in flow_service.chat_flow(request.chat_param, flow_req):
        if output.startswith("data: [DONE]"):
            yield output
        if output.startswith("data:"):
            output = output[len("data: ") :]
            choice_data = ChatCompletionResponseStreamChoice(
                index=0,
                delta=DeltaMessage(role="assistant", content=output),
            )
            chunk = ChatCompletionStreamResponse(
                id=request.conv_uid,
                choices=[choice_data],
                model=request.model,
                created=int(time.time()),
            )
            chat_completion_response = (
                f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
            )
            yield chat_completion_response


def check_chat_request(request: ChatCompletionRequestBody = Body()):
    """
    Check the chat request
    Args:
        request (ChatCompletionRequestBody): The chat request.
    Raises:
        HTTPException: If the request is invalid.
    """
    if request.chat_mode and request.chat_mode != ChatScene.ChatNormal.value():
        if request.chat_param is None:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "chat param is None",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_chat_param",
                    }
                },
            )
    if request.model is None:
        raise HTTPException(
            status_code=400,
            detail={
                "error": {
                    "message": "model is None",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_model",
                }
            },
        )
    if request.messages is None:
        raise HTTPException(
            status_code=400,
            detail={
                "error": {
                    "message": "messages is None",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_messages",
                }
            },
        )
```
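Beyond the bundled client, the v2 stream can be consumed with any SSE-capable HTTP client; a minimal hedged sketch using `httpx` (an assumed dependency):

```python
import asyncio
import json

import httpx  # assumed dependency; any streaming HTTP client works


async def main():
    headers = {"Authorization": "Bearer dbgpt"}
    body = {"messages": "Hello", "model": "chatgpt_proxyllm", "stream": True}
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST",
            "http://localhost:5000/api/v2/chat/completions",
            json=body,
            headers=headers,
        ) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data:"):
                    continue
                payload = line[len("data:") :].strip()
                if payload == "[DONE]":
                    break
                chunk = json.loads(payload)
                # Each chunk mirrors the OpenAI delta format produced above.
                print(chunk["choices"][0]["delta"].get("content", ""), end="")


asyncio.run(main())
```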
```diff
@@ -89,21 +89,3 @@ class MessageVo(BaseModel):
         model_name
     """
     model_name: str
-
-
-class DeltaMessage(BaseModel):
-    role: Optional[str] = None
-    content: Optional[str] = None
-
-
-class ChatCompletionResponseStreamChoice(BaseModel):
-    index: int
-    delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length"]] = None
-
-
-class ChatCompletionStreamResponse(BaseModel):
-    id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
-    created: int = Field(default_factory=lambda: int(time.time()))
-    model: str
-    choices: List[ChatCompletionResponseStreamChoice]
```
````diff
@@ -14,14 +14,22 @@ PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelli
 The assistant gives helpful, detailed, professional and polite answers to the user's questions. """
 
 
-_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
-如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造, 回答的时候最好按照1.2.3.点进行总结。
+_DEFAULT_TEMPLATE_ZH = """ 基于以下给出的已知信息, 准守规范约束,专业、简要回答用户的问题.
+规范约束:
+1.如果已知信息包含的图片、链接、表格、代码块等特殊markdown标签格式的信息,确保在答案中包含原文这些图片、链接、表格和代码标签,不要丢弃不要修改,如:图片格式:![xxx](xxx), 链接格式:[xxx](xxx), 表格格式:|xxx|xxx|xxx|, 代码格式:```xxx```.
+2.如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造.
+3.回答的时候最好按照1.2.3.点进行总结.
 已知内容:
 {context}
 问题:
 {question},请使用和用户相同的语言进行回答.
 """
-_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly. When answering, it is best to summarize according to points 1.2.3.
+_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions.
+constraints:
+1.Ensure to include original markdown formatting elements such as images, links, tables, or code blocks without alteration in the response if they are present in the provided information.
+For example, image format should be ![xxx](xxx), link format [xxx](xxx), table format should be represented with |xxx|xxx|xxx|, and code format with ```xxx```.
+2.If the information available in the knowledge base is insufficient to answer the question, state clearly: "The content provided in the knowledge base is not enough to answer this question," and avoid making up answers.
+3.When responding, it is best to summarize the points in the order of 1, 2, 3.
 known information:
 {context}
 question:
````