mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-09 12:18:12 +00:00)

feat: add knowledge space argument

parent cc57ed22ca
commit aa3c3205a4
@@ -163,7 +163,9 @@ class Config(metaclass=Singleton):

        ### EMBEDDING Configuration
        self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
        self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
        self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 100))
        self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
        self.KNOWLEDGE_SEARCH_MAX_TOKEN = int(os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000))

        ### SUMMARY_CONFIG Configuration
        self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")
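For reference, a minimal sketch of overriding the knowledge defaults above through environment variables before Config() is first constructed. The variable names and defaults come from the diff; the values are illustrative, and the Config import path is the one used elsewhere in this commit.

import os

# Illustrative values only; names and defaults are defined in the Config lines above.
os.environ["EMBEDDING_MODEL"] = "text2vec"
os.environ["KNOWLEDGE_CHUNK_SIZE"] = "500"         # characters per chunk fed to the text splitter
os.environ["KNOWLEDGE_CHUNK_OVERLAP"] = "50"       # characters shared between adjacent chunks
os.environ["KNOWLEDGE_SEARCH_TOP_SIZE"] = "5"      # top-k chunks retrieved per question
os.environ["KNOWLEDGE_SEARCH_MAX_TOKEN"] = "2000"  # cap applied to the retrieved context

from pilot.configs.config import Config  # import path as used elsewhere in this commit

cfg = Config()
assert cfg.KNOWLEDGE_CHUNK_OVERLAP == 50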
@@ -20,6 +20,7 @@ from pilot.configs.model_config import (

from pilot.scene.chat_knowledge.v1.prompt import prompt
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.server.knowledge.service import KnowledgeService

CFG = Config()
@@ -36,6 +37,18 @@ class ChatKnowledge(BaseChat):

            chat_session_id=chat_session_id,
            current_user_input=user_input,
        )
        self.space_context = self.get_space_context(knowledge_space)
        self.top_k = (
            CFG.KNOWLEDGE_SEARCH_TOP_SIZE
            if self.space_context is None
            else int(self.space_context["embedding"]["topk"])
        )
        # self.recall_score = CFG.KNOWLEDGE_SEARCH_TOP_SIZE if self.space_context is None else self.space_context["embedding"]["recall_score"]
        self.max_token = (
            CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
            if self.space_context is None
            else int(self.space_context["prompt"]["max_token"])
        )
        vector_store_config = {
            "vector_store_name": knowledge_space,
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
@@ -48,11 +61,14 @@ class ChatKnowledge(BaseChat):

    def generate_input_values(self):
        try:
            if self.space_context:
                prompt.template_define = self.space_context["prompt"]["scene"]
                prompt.template = self.space_context["prompt"]["template"]
            docs = self.knowledge_embedding_client.similar_search(
                self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
                self.current_user_input, self.top_k
            )
            context = [d.page_content for d in docs]
            context = context[:2000]
            context = context[: self.max_token]
            input_values = {"context": context, "question": self.current_user_input}
        except NoIndexException:
            raise ValueError(

@@ -63,3 +79,7 @@ class ChatKnowledge(BaseChat):

    @property
    def chat_type(self) -> str:
        return ChatScene.ChatKnowledge.value()

    def get_space_context(self, space_name):
        service = KnowledgeService()
        return service.get_space_context(space_name)
@@ -20,6 +20,7 @@ from pilot.server.knowledge.request.request import (

    DocumentSyncRequest,
    ChunkQueryRequest,
    DocumentQueryRequest,
    SpaceArgumentRequest,
)

from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
@@ -54,13 +55,33 @@ def space_list(request: KnowledgeSpaceRequest):

@router.post("/knowledge/space/delete")
def space_delete(request: KnowledgeSpaceRequest):
    print(f"/space/list params:")
    print(f"/space/delete params:")
    try:
        return Result.succ(knowledge_space_service.delete_space(request.name))
    except Exception as e:
        return Result.faild(code="E000X", msg=f"space list error {e}")


@router.post("/knowledge/{space_name}/arguments")
def arguments(space_name: str):
    print(f"/knowledge/space/arguments params:")
    try:
        return Result.succ(knowledge_space_service.arguments(space_name))
    except Exception as e:
        return Result.faild(code="E000X", msg=f"space list error {e}")


@router.post("/knowledge/{space_name}/argument/save")
def arguments_save(space_name: str, argument_request: SpaceArgumentRequest):
    print(f"/knowledge/space/argument/save params:")
    try:
        return Result.succ(
            knowledge_space_service.argument_save(space_name, argument_request)
        )
    except Exception as e:
        return Result.faild(code="E000X", msg=f"space list error {e}")


@router.post("/knowledge/{space_name}/document/add")
def document_add(space_name: str, request: KnowledgeDocumentRequest):
    print(f"/document/add params: {space_name}, {request}")
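For orientation, a hedged usage sketch of the two new endpoints. Only the route suffixes and the argument field of SpaceArgumentRequest come from the diff; the host, URL prefix, and the assumption that Result.succ wraps its payload in a "data" field are guesses.

import json
import requests

BASE = "http://localhost:5000/api"  # host and prefix are assumptions
space = "my_space"                  # hypothetical knowledge space name

# Fetch the current (or default) arguments for the space.
resp = requests.post(f"{BASE}/knowledge/{space}/arguments").json()
args = resp.get("data", resp)  # response envelope shape is an assumption

# Tweak a couple of values and save the whole document back.
args["embedding"]["topk"] = 10
args["prompt"]["max_token"] = 4000
requests.post(
    f"{BASE}/knowledge/{space}/argument/save",
    json={"argument": json.dumps(args)},  # SpaceArgumentRequest.argument is a JSON string
)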
@@ -83,3 +83,9 @@ class KnowledgeQueryResponse:

    score: float = 0.0
    """text: raw text info"""
    text: str


class SpaceArgumentRequest(BaseModel):
    """argument: argument"""

    argument: str
@@ -32,6 +32,8 @@ class SpaceQueryResponse(BaseModel):

    vector_type: str = None
    """desc: description"""
    desc: str = None
    """context: context"""
    context: str = None
    """owner: owner"""
    owner: str = None
    gmt_created: str = None
@@ -1,3 +1,4 @@

import json
import threading
from datetime import datetime

@@ -25,6 +26,7 @@ from pilot.server.knowledge.request.request import (

    KnowledgeDocumentRequest,
    DocumentQueryRequest,
    ChunkQueryRequest,
    SpaceArgumentRequest,
)
from enum import Enum
@@ -102,12 +104,34 @@ class KnowledgeService:

            res.gmt_created = space.gmt_created
            res.gmt_modified = space.gmt_modified
            res.owner = space.owner
            res.context = space.context
            query = KnowledgeDocumentEntity(space=space.name)
            doc_count = knowledge_document_dao.get_knowledge_documents_count(query)
            res.docs = doc_count
            responses.append(res)
        return responses

    def arguments(self, space_name):
        query = KnowledgeSpaceEntity(name=space_name)
        spaces = knowledge_space_dao.get_knowledge_space(query)
        if len(spaces) != 1:
            raise Exception(f"there are no or more than one space called {space_name}")
        space = spaces[0]
        if space.context is None:
            context = self._build_default_context()
        else:
            context = space.context
        return json.loads(context)

    def argument_save(self, space_name, argument_request: SpaceArgumentRequest):
        query = KnowledgeSpaceEntity(name=space_name)
        spaces = knowledge_space_dao.get_knowledge_space(query)
        if len(spaces) != 1:
            raise Exception(f"there are no or more than one space called {space_name}")
        space = spaces[0]
        space.context = argument_request.argument
        return knowledge_space_dao.update_knowledge_space(space)

    """get knowledge get_knowledge_documents"""

    def get_knowledge_documents(self, space, request: DocumentQueryRequest):
@@ -142,22 +166,34 @@ class KnowledgeService:

                f" doc:{doc.doc_name} status is {doc.status}, can not sync"
            )

        space_context = self.get_space_context(space_name)
        chunk_size = (
            CFG.KNOWLEDGE_CHUNK_SIZE
            if space_context is None
            else int(space_context["embedding"]["chunk_size"])
        )
        chunk_overlap = (
            CFG.KNOWLEDGE_CHUNK_OVERLAP
            if space_context is None
            else int(space_context["embedding"]["chunk_overlap"])
        )
        if CFG.LANGUAGE == "en":
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
                chunk_overlap=20,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
            )
        else:
            try:
                text_splitter = SpacyTextSplitter(
                    pipeline="zh_core_web_sm",
                    chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
                    chunk_overlap=100,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                )
            except Exception:
                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                )
        client = EmbeddingEngine(
            knowledge_source=doc.content,
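As a side note, a small sketch of what the two knobs wired through above control in LangChain's splitter. The splitter class and keyword arguments appear in the diff; the sample text and values are illustrative.

from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    chunk_size=100,      # maximum characters per chunk (measured by length_function)
    chunk_overlap=20,    # trailing characters repeated at the start of the next chunk
    length_function=len,
)
chunks = splitter.split_text("DB-GPT stores each knowledge document as overlapping chunks. " * 20)
# Larger chunk_size -> fewer, longer chunks; larger chunk_overlap -> more shared context between neighbours.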
@@ -287,3 +323,40 @@ class KnowledgeService:

            doc.result = "document embedding failed" + str(e)
            logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}")
        return knowledge_document_dao.update_knowledge_document(doc)

    def _build_default_context(self):
        from pilot.scene.chat_knowledge.v1.prompt import (
            PROMPT_SCENE_DEFINE,
            _DEFAULT_TEMPLATE,
        )

        context_template = {
            "embedding": {
                "topk": CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
                "recall_score": 0.0,
                "recall_type": "TopK",
                "model": CFG.EMBEDDING_MODEL,
                "chunk_size": CFG.KNOWLEDGE_CHUNK_SIZE,
                "chunk_overlap": CFG.KNOWLEDGE_CHUNK_OVERLAP,
            },
            "prompt": {
                "max_token": 2000,
                "scene": PROMPT_SCENE_DEFINE,
                "template": _DEFAULT_TEMPLATE,
            },
        }
        context_template_string = json.dumps(context_template, indent=4)
        return context_template_string

    def get_space_context(self, space_name):
        request = KnowledgeSpaceRequest()
        request.name = space_name
        spaces = self.get_knowledge_space(request)
        if len(spaces) != 1:
            raise Exception(
                f"have not found {space_name} space or found more than one space called {space_name}"
            )
        space = spaces[0]
        if space.context is not None:
            return json.loads(spaces[0].context)
        return None
@@ -1,6 +1,6 @@

from datetime import datetime

from sqlalchemy import Column, Integer, String, DateTime, create_engine
from sqlalchemy import Column, Integer, Text, String, DateTime, create_engine
from sqlalchemy.ext.declarative import declarative_base

from pilot.configs.config import Config
@@ -19,11 +19,12 @@ class KnowledgeSpaceEntity(Base):

    vector_type = Column(String(100))
    desc = Column(String(100))
    owner = Column(String(100))
    context = Column(Text)
    gmt_created = Column(DateTime)
    gmt_modified = Column(DateTime)

    def __repr__(self):
        return f"KnowledgeSpaceEntity(id={self.id}, name='{self.name}', vector_type='{self.vector_type}', desc='{self.desc}', owner='{self.owner}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
        return f"KnowledgeSpaceEntity(id={self.id}, name='{self.name}', vector_type='{self.vector_type}', desc='{self.desc}', owner='{self.owner}' context='{self.context}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"


class KnowledgeSpaceDao:
@@ -88,14 +89,12 @@ class KnowledgeSpaceDao:

        session.close()
        return result

    def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity):
        cursor = self.conn.cursor()
        query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s"
        cursor.execute(
            query, (space.name, space.vector_type, space.desc, space.owner, space_id)
        )
        self.conn.commit()
        cursor.close()
    def update_knowledge_space(self, space: KnowledgeSpaceEntity):
        session = self.Session()
        session.merge(space)
        session.commit()
        session.close()
        return True

    def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
        session = self.Session()
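For context (not part of the diff): Session.merge() copies the state of the detached entity onto the persistent row with the same primary key, so the rewritten method updates whatever columns were changed on the object passed in, the new context column included. A hedged usage sketch; the class names and lookup pattern come from the diff, the import path is assumed.

import json

# Hypothetical import path; KnowledgeSpaceDao and KnowledgeSpaceEntity come from the diff.
from pilot.server.knowledge.knowledge_space_db import KnowledgeSpaceDao, KnowledgeSpaceEntity

dao = KnowledgeSpaceDao()
spaces = dao.get_knowledge_space(KnowledgeSpaceEntity(name="my_space"))  # same lookup as KnowledgeService
space = spaces[0]
space.context = json.dumps({"prompt": {"max_token": 4000}})  # illustrative partial context
dao.update_knowledge_space(space)  # merge + commit, returns True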
@@ -1,4 +1,5 @@

import os
from typing import Any

from chromadb.config import Settings
from langchain.vectorstores import Chroma
@@ -20,13 +21,15 @@ class ChromaStore(VectorStoreBase):

            persist_directory=self.persist_dir,
            anonymized_telemetry=False,
        )
        collection_metadata = {"hnsw:space": "cosine"}
        self.vector_store_client = Chroma(
            persist_directory=self.persist_dir,
            embedding_function=self.embeddings,
            client_settings=chroma_settings,
            collection_metadata=collection_metadata,
        )

    def similar_search(self, text, topk) -> None:
    def similar_search(self, text, topk, **kwargs: Any) -> None:
        logger.info("ChromaStore similar search")
        return self.vector_store_client.similarity_search(text, topk)
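A brief aside on the new collection_metadata (not from the diff): passing {"hnsw:space": "cosine"} asks Chroma's HNSW index to rank neighbours by cosine distance rather than its default L2 metric, which generally suits text embeddings better. A tiny illustration of the metric itself:

import numpy as np

def cosine_distance(a, b):
    # Cosine distance = 1 - cosine similarity; smaller means more similar.
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

print(cosine_distance([1.0, 0.0], [1.0, 1.0]))  # about 0.293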
@@ -1,9 +1,8 @@

from pilot.vector_store.chroma_store import ChromaStore

from pilot.vector_store.milvus_store import MilvusStore
from pilot.vector_store.weaviate_store import WeaviateStore

connector = {"Chroma": ChromaStore, "Milvus": MilvusStore, "Weaviate": WeaviateStore}
connector = {"Chroma": ChromaStore, "Milvus": MilvusStore}


class VectorStoreConnector: