fix:rag serve error

This commit is contained in:
aries_ckt 2024-03-18 19:37:06 +08:00
parent 0ed30aa44a
commit 4413ff682f
10 changed files with 168 additions and 21 deletions

View File

@ -241,4 +241,5 @@ DBGPT_LOG_LEVEL=INFO
#*******************************************************************#
#** API_KEYS **#
#*******************************************************************#
# API_KEYS - The list of API keys that are allowed to access the API. Each key listed below is an option; separate multiple keys with commas.
# API_KEYS=dbgpt

View File

@ -26,7 +26,7 @@ from dbgpt.app.openapi.api_v1.api_v1 import (
stream_generator,
)
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.client.schemas import ChatCompletionRequestBody
from dbgpt.client.schemas import ChatCompletionRequestBody, ChatMode
from dbgpt.component import logger
from dbgpt.core.awel import CommonLLMHttpRequestBody, CommonLLMHTTPRequestContext
from dbgpt.model.cluster.apiserver.api import APISettings
@ -94,7 +94,7 @@ async def chat_completions(
check_chat_request(request)
if request.conv_uid is None:
request.conv_uid = str(uuid.uuid4())
if request.chat_mode == "chat_app":
if request.chat_mode == ChatMode.CHAT_APP.value:
if request.stream is False:
raise HTTPException(
status_code=400,
@ -114,27 +114,14 @@ async def chat_completions(
headers=headers,
media_type="text/event-stream",
)
elif request.chat_mode == ChatScene.ChatFlow.value():
# flow_ctx = CommonLLMHTTPRequestContext(
# conv_uid=request.conv_uid,
# chat_mode=request.chat_mode,
# user_name=request.user_name,
# sys_code=request.sys_code,
# )
# flow_req = CommonLLMHttpRequestBody(
# model=request.model,
# messages=request.chat_param,
# stream=True,
# context=flow_ctx,
# )
elif request.chat_mode == ChatMode.CHAT_AWEL_FLOW.value:
return StreamingResponse(
chat_flow_stream_wrapper(request),
headers=headers,
media_type="text/event-stream",
)
elif (
request.chat_mode is None
or request.chat_mode == ChatScene.ChatKnowledge.value()
request.chat_mode is None or request.chat_mode == ChatMode.CHAT_KNOWLEDGE.value
):
with root_tracer.start_span(
"get_chat_instance", span_type=SpanType.CHAT, metadata=request.dict()

View File

@ -1,4 +1,5 @@
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Union
from fastapi import File, UploadFile
@ -60,6 +61,15 @@ class ChatCompletionRequestBody(BaseModel):
)
class ChatMode(Enum):
    """Enumeration of the chat modes, each carrying its string value.

    Members are compared through ``.value``, e.g.
    ``ChatMode.CHAT_APP.value == "chat_app"``.
    """

    # Member order is kept as declared; iteration follows this order.
    CHAT_NORMAL = "chat_normal"
    CHAT_APP = "chat_app"
    # Note: the member name says AWEL_FLOW but the value is "chat_flow".
    CHAT_AWEL_FLOW = "chat_flow"
    CHAT_KNOWLEDGE = "chat_knowledge"
class SpaceModel(BaseModel):
"""name: knowledge space name"""

View File

View File

@ -0,0 +1,147 @@
from datetime import datetime
from typing import Any, Dict, List, Union
from sqlalchemy import Column, DateTime, Integer, String, Text
from dbgpt.serve.rag.api.schemas import SpaceServeRequest, SpaceServeResponse
from dbgpt.storage.metadata import BaseDao, Model
class KnowledgeSpaceEntity(Model):
    """ORM mapping for a row of the ``knowledge_space`` table."""

    __tablename__ = "knowledge_space"

    id = Column(Integer, primary_key=True)
    name = Column(String(100))
    vector_type = Column(String(100))
    desc = Column(String(100))
    owner = Column(String(100))
    context = Column(Text)
    gmt_created = Column(DateTime)
    gmt_modified = Column(DateTime)

    def __repr__(self) -> str:
        """Return a debug representation listing every mapped column."""
        # Split across lines for readability; the separator between
        # ``owner`` and ``context`` is now a comma like all the others.
        return (
            f"KnowledgeSpaceEntity(id={self.id}, name='{self.name}', "
            f"vector_type='{self.vector_type}', desc='{self.desc}', "
            f"owner='{self.owner}', context='{self.context}', "
            f"gmt_created='{self.gmt_created}', "
            f"gmt_modified='{self.gmt_modified}')"
        )
class KnowledgeSpaceDao(BaseDao):
    """Data-access object for ``KnowledgeSpaceEntity`` rows.

    Every method obtains its own session via ``get_raw_session()`` and closes
    it before returning, including when the database call raises.
    """

    # Entity attributes that ``get_knowledge_space`` filters on when the
    # corresponding attribute of the query object is not ``None``.
    _FILTER_FIELDS = (
        "id",
        "name",
        "vector_type",
        "desc",
        "owner",
        "gmt_created",
        "gmt_modified",
    )

    def create_knowledge_space(self, space: SpaceServeRequest):
        """Insert a new knowledge space and return its generated id.

        Args:
            space (SpaceServeRequest): Supplies ``name``, ``vector_type``,
                ``desc`` and ``owner`` for the new row.

        Returns:
            The ``id`` of the newly inserted row.
        """
        session = self.get_raw_session()
        try:
            # Use a single timestamp so created/modified are identical on
            # insert.
            now = datetime.now()
            knowledge_space = KnowledgeSpaceEntity(
                name=space.name,
                vector_type=space.vector_type,
                desc=space.desc,
                owner=space.owner,
                gmt_created=now,
                gmt_modified=now,
            )
            session.add(knowledge_space)
            session.commit()
            # Read the id while the session is still open.
            # NOTE(review): presumably needed because attributes are expired
            # on commit and must be reloaded through the session - confirm.
            return knowledge_space.id
        finally:
            session.close()

    def get_knowledge_space(
        self, query: KnowledgeSpaceEntity
    ) -> List[KnowledgeSpaceEntity]:
        """Return the knowledge spaces matching the non-``None`` query fields.

        Args:
            query (KnowledgeSpaceEntity): Example entity; each attribute in
                ``_FILTER_FIELDS`` that is not ``None`` becomes an equality
                filter.

        Returns:
            List[KnowledgeSpaceEntity]: Matching rows ordered by
            ``gmt_created`` descending.
        """
        session = self.get_raw_session()
        try:
            knowledge_spaces = session.query(KnowledgeSpaceEntity)
            for field in self._FILTER_FIELDS:
                value = getattr(query, field)
                if value is not None:
                    knowledge_spaces = knowledge_spaces.filter(
                        getattr(KnowledgeSpaceEntity, field) == value
                    )
            knowledge_spaces = knowledge_spaces.order_by(
                KnowledgeSpaceEntity.gmt_created.desc()
            )
            return knowledge_spaces.all()
        finally:
            session.close()

    def update_knowledge_space(self, space: KnowledgeSpaceEntity) -> bool:
        """Merge the given entity state into the database and commit.

        Args:
            space (KnowledgeSpaceEntity): The entity carrying the new state.

        Returns:
            bool: Always ``True`` when the commit succeeds.
        """
        session = self.get_raw_session()
        try:
            session.merge(space)
            session.commit()
            return True
        finally:
            session.close()

    def delete_knowledge_space(self, space: KnowledgeSpaceEntity) -> None:
        """Delete the given knowledge space; do nothing if ``space`` is falsy.

        Args:
            space (KnowledgeSpaceEntity): The entity to delete.
        """
        session = self.get_raw_session()
        try:
            if space:
                session.delete(space)
                session.commit()
        finally:
            session.close()

    def from_request(
        self, request: Union[SpaceServeRequest, Dict[str, Any]]
    ) -> KnowledgeSpaceEntity:
        """Convert the request to an entity.

        Args:
            request (Union[SpaceServeRequest, Dict[str, Any]]): The request,
                either as a model (converted with ``.dict()``) or as a plain
                mapping of entity attributes.

        Returns:
            KnowledgeSpaceEntity: The entity built from the request fields.
        """
        request_dict = (
            request.dict() if isinstance(request, SpaceServeRequest) else request
        )
        return KnowledgeSpaceEntity(**request_dict)

    def to_request(self, entity: KnowledgeSpaceEntity) -> SpaceServeRequest:
        """Convert the entity to a request.

        Args:
            entity (KnowledgeSpaceEntity): The entity.

        Returns:
            SpaceServeRequest: Request carrying ``id``, ``name``,
            ``vector_type``, ``desc`` and ``owner`` of the entity.
        """
        return SpaceServeRequest(
            id=entity.id,
            name=entity.name,
            vector_type=entity.vector_type,
            desc=entity.desc,
            owner=entity.owner,
        )

    def to_response(self, entity: KnowledgeSpaceEntity) -> SpaceServeResponse:
        """Convert the entity to a response.

        Args:
            entity (KnowledgeSpaceEntity): The entity.

        Returns:
            SpaceServeResponse: Response carrying ``id``, ``name``,
            ``vector_type``, ``desc`` and ``owner`` of the entity.
        """
        return SpaceServeResponse(
            id=entity.id,
            name=entity.name,
            vector_type=entity.vector_type,
            desc=entity.desc,
            owner=entity.owner,
        )

View File

@ -21,8 +21,8 @@ from dbgpt.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from dbgpt.core import Chunk
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding import EmbeddingFactory
from dbgpt.rag.knowledge import ChunkStrategy, KnowledgeFactory, KnowledgeType

View File

@ -23,8 +23,8 @@ Client: Simple App CRUD example
async def main():
# initialize client
# initialize client
DBGPT_API_KEY = "dbgpt"
client = Client(api_key=DBGPT_API_KEY)
res = await list_app(client)

View File

@ -55,6 +55,7 @@ Client: Simple Chat example
async def main():
# initialize client
DBGPT_API_KEY = "dbgpt"
client = Client(api_key=DBGPT_API_KEY)

View File

@ -36,8 +36,8 @@ Client: Simple Flow CRUD example
async def main():
# initialize client
# initialize client
DBGPT_API_KEY = "dbgpt"
client = Client(api_key=DBGPT_API_KEY)
res = await list_flow(client)

View File

@ -69,12 +69,13 @@ from dbgpt.client.knowledge import list_space
async def main():
# initialize client
DBGPT_API_KEY = "dbgpt"
client = Client(api_key=DBGPT_API_KEY)
# list all spaces
res = await list_space(client)
print(res)
print(res.json())
# get space
# res = await get_space(client, space_id='5')