Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-11 13:12:18 +00:00)

feat:add knowledge space argument

Commit aa3c3205a4, parent cc57ed22ca.
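Summary of the change: each knowledge space can now carry its own retrieval and prompt arguments. A context column (JSON text) is added to the knowledge_space table, two new API endpoints read and save those arguments, KnowledgeService resolves the stored context or builds a default one, and both ChatKnowledge and the document-sync path use the per-space values (topk, max_token, chunk_size, chunk_overlap, prompt scene/template) in place of the global config defaults. The vector-store connector also drops its Weaviate entry, and ChromaStore switches its collection to cosine distance.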
@@ -163,7 +163,9 @@ class Config(metaclass=Singleton):
         ### EMBEDDING Configuration
         self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
         self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
+        self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 100))
         self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
+        self.KNOWLEDGE_SEARCH_MAX_TOKEN = int(os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000))
         ### SUMMARY_CONFIG Configuration
         self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")
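The two new settings default to a chunk overlap of 100 and a cap of 2000 used to truncate retrieved context. A minimal sketch of overriding them via environment variables, assuming they are exported before the Config singleton is first constructed:

    import os

    # Illustrative values only; the project defaults are 100 and 2000.
    os.environ["KNOWLEDGE_CHUNK_OVERLAP"] = "50"
    os.environ["KNOWLEDGE_SEARCH_MAX_TOKEN"] = "4000"

    from pilot.configs.config import Config

    cfg = Config()
    print(cfg.KNOWLEDGE_CHUNK_OVERLAP, cfg.KNOWLEDGE_SEARCH_MAX_TOKEN)  # 50 4000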
@@ -20,6 +20,7 @@ from pilot.configs.model_config import (

 from pilot.scene.chat_knowledge.v1.prompt import prompt
 from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+from pilot.server.knowledge.service import KnowledgeService

 CFG = Config()
@@ -36,6 +37,18 @@ class ChatKnowledge(BaseChat):
             chat_session_id=chat_session_id,
             current_user_input=user_input,
         )
+        self.space_context = self.get_space_context(knowledge_space)
+        self.top_k = (
+            CFG.KNOWLEDGE_SEARCH_TOP_SIZE
+            if self.space_context is None
+            else int(self.space_context["embedding"]["topk"])
+        )
+        # self.recall_score = CFG.KNOWLEDGE_SEARCH_TOP_SIZE if self.space_context is None else self.space_context["embedding"]["recall_score"]
+        self.max_token = (
+            CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
+            if self.space_context is None
+            else int(self.space_context["prompt"]["max_token"])
+        )
         vector_store_config = {
             "vector_store_name": knowledge_space,
             "vector_store_type": CFG.VECTOR_STORE_TYPE,
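The constructor now resolves retrieval parameters per space: the global config values apply when no space context exists, otherwise the stored JSON wins. A self-contained sketch of that resolution (the dict shape mirrors the default context built by KnowledgeService later in this commit):

    # Illustrative helper, not part of the codebase.
    def resolve_retrieval_params(space_context, default_top_k=5, default_max_token=2000):
        if space_context is None:
            return default_top_k, default_max_token
        return (
            int(space_context["embedding"]["topk"]),
            int(space_context["prompt"]["max_token"]),
        )

    example = {"embedding": {"topk": 10}, "prompt": {"max_token": 3000}}
    print(resolve_retrieval_params(example))  # (10, 3000)
    print(resolve_retrieval_params(None))     # (5, 2000)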
@@ -48,11 +61,14 @@ class ChatKnowledge(BaseChat):

     def generate_input_values(self):
         try:
+            if self.space_context:
+                prompt.template_define = self.space_context["prompt"]["scene"]
+                prompt.template = self.space_context["prompt"]["template"]
             docs = self.knowledge_embedding_client.similar_search(
-                self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
+                self.current_user_input, self.top_k
             )
             context = [d.page_content for d in docs]
-            context = context[:2000]
+            context = context[: self.max_token]
             input_values = {"context": context, "question": self.current_user_input}
         except NoIndexException:
             raise ValueError(
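When a space context is present, generate_input_values also swaps the scene definition and template on the imported prompt object before retrieval. A stand-alone sketch of that substitution (PromptStub is a stand-in for the project's prompt object, and the strings are illustrative):

    class PromptStub:
        template_define = "You are a helpful assistant."
        template = "Answer {question} using {context}."

    prompt = PromptStub()
    space_context = {
        "prompt": {
            "scene": "You are a domain expert for this knowledge space.",
            "template": "Ground your answer to {question} strictly in {context}.",
        }
    }
    if space_context:
        prompt.template_define = space_context["prompt"]["scene"]
        prompt.template = space_context["prompt"]["template"]
    print(prompt.template)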
@@ -63,3 +79,7 @@ class ChatKnowledge(BaseChat):
     @property
     def chat_type(self) -> str:
         return ChatScene.ChatKnowledge.value()
+
+    def get_space_context(self, space_name):
+        service = KnowledgeService()
+        return service.get_space_context(space_name)
@@ -20,6 +20,7 @@ from pilot.server.knowledge.request.request import (
     DocumentSyncRequest,
     ChunkQueryRequest,
     DocumentQueryRequest,
+    SpaceArgumentRequest,
 )

 from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
@@ -54,13 +55,33 @@ def space_list(request: KnowledgeSpaceRequest):

 @router.post("/knowledge/space/delete")
 def space_delete(request: KnowledgeSpaceRequest):
-    print(f"/space/list params:")
+    print(f"/space/delete params:")
     try:
         return Result.succ(knowledge_space_service.delete_space(request.name))
     except Exception as e:
         return Result.faild(code="E000X", msg=f"space list error {e}")


+@router.post("/knowledge/{space_name}/arguments")
+def arguments(space_name: str):
+    print(f"/knowledge/space/arguments params:")
+    try:
+        return Result.succ(knowledge_space_service.arguments(space_name))
+    except Exception as e:
+        return Result.faild(code="E000X", msg=f"space list error {e}")
+
+
+@router.post("/knowledge/{space_name}/argument/save")
+def arguments_save(space_name: str, argument_request: SpaceArgumentRequest):
+    print(f"/knowledge/space/argument/save params:")
+    try:
+        return Result.succ(
+            knowledge_space_service.argument_save(space_name, argument_request)
+        )
+    except Exception as e:
+        return Result.faild(code="E000X", msg=f"space list error {e}")
+
+
 @router.post("/knowledge/{space_name}/document/add")
 def document_add(space_name: str, request: KnowledgeDocumentRequest):
     print(f"/document/add params: {space_name}, {request}")
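A hedged client-side sketch of the two new endpoints using the requests library; the base URL and router prefix are assumptions that depend on deployment, and the response envelope comes from Result.succ/faild. Note that SpaceArgumentRequest.argument carries the whole context serialized as a JSON string:

    import json
    import requests

    BASE = "http://localhost:5000"  # hypothetical host/port and prefix

    # Read the current (or default) arguments for a space.
    resp = requests.post(f"{BASE}/knowledge/my_space/arguments")
    print(resp.json())

    # Save new arguments; a complete context also includes prompt "scene" and
    # "template" (see the default context further down).
    new_context = {
        "embedding": {"topk": 10, "chunk_size": 500, "chunk_overlap": 50},
        "prompt": {"max_token": 3000},
    }
    resp = requests.post(
        f"{BASE}/knowledge/my_space/argument/save",
        json={"argument": json.dumps(new_context)},
    )
    print(resp.json())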
@@ -83,3 +83,9 @@ class KnowledgeQueryResponse:
     score: float = 0.0
     """text: raw text info"""
     text: str
+
+
+class SpaceArgumentRequest(BaseModel):
+    """argument: argument"""
+
+    argument: str
@@ -32,6 +32,8 @@ class SpaceQueryResponse(BaseModel):
     vector_type: str = None
     """desc: description"""
     desc: str = None
+    """context: context"""
+    context: str = None
     """owner: owner"""
     owner: str = None
     gmt_created: str = None
@@ -1,3 +1,4 @@
+import json
 import threading
 from datetime import datetime
@@ -25,6 +26,7 @@ from pilot.server.knowledge.request.request import (
     KnowledgeDocumentRequest,
     DocumentQueryRequest,
     ChunkQueryRequest,
+    SpaceArgumentRequest,
 )
 from enum import Enum
@@ -102,12 +104,34 @@ class KnowledgeService:
             res.gmt_created = space.gmt_created
             res.gmt_modified = space.gmt_modified
             res.owner = space.owner
+            res.context = space.context
             query = KnowledgeDocumentEntity(space=space.name)
             doc_count = knowledge_document_dao.get_knowledge_documents_count(query)
             res.docs = doc_count
             responses.append(res)
         return responses

+    def arguments(self, space_name):
+        query = KnowledgeSpaceEntity(name=space_name)
+        spaces = knowledge_space_dao.get_knowledge_space(query)
+        if len(spaces) != 1:
+            raise Exception(f"there are no or more than one space called {space_name}")
+        space = spaces[0]
+        if space.context is None:
+            context = self._build_default_context()
+        else:
+            context = space.context
+        return json.loads(context)
+
+    def argument_save(self, space_name, argument_request: SpaceArgumentRequest):
+        query = KnowledgeSpaceEntity(name=space_name)
+        spaces = knowledge_space_dao.get_knowledge_space(query)
+        if len(spaces) != 1:
+            raise Exception(f"there are no or more than one space called {space_name}")
+        space = spaces[0]
+        space.context = argument_request.argument
+        return knowledge_space_dao.update_knowledge_space(space)
+
     """get knowledge get_knowledge_documents"""

     def get_knowledge_documents(self, space, request: DocumentQueryRequest):
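The two service methods form a simple round trip: argument_save persists the raw JSON string into the space's context column, and arguments parses it back, falling back to a generated default when nothing has been saved. The storage contract, stripped of the DAO layer:

    import json

    # What argument_save stores (values illustrative) ...
    stored = json.dumps({"embedding": {"topk": 5}, "prompt": {"max_token": 2000}})

    # ... and what arguments() hands back to the API caller.
    parsed = json.loads(stored)
    print(parsed["embedding"]["topk"])  # 5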
@@ -142,22 +166,34 @@ class KnowledgeService:
                 f" doc:{doc.doc_name} status is {doc.status}, can not sync"
             )

+        space_context = self.get_space_context(space_name)
+        chunk_size = (
+            CFG.KNOWLEDGE_CHUNK_SIZE
+            if space_context is None
+            else int(space_context["embedding"]["chunk_size"])
+        )
+        chunk_overlap = (
+            CFG.KNOWLEDGE_CHUNK_OVERLAP
+            if space_context is None
+            else int(space_context["embedding"]["chunk_overlap"])
+        )
         if CFG.LANGUAGE == "en":
             text_splitter = RecursiveCharacterTextSplitter(
-                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
-                chunk_overlap=20,
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
                 length_function=len,
             )
         else:
             try:
                 text_splitter = SpacyTextSplitter(
                     pipeline="zh_core_web_sm",
-                    chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
-                    chunk_overlap=100,
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap
                 )
             except Exception:
                 text_splitter = RecursiveCharacterTextSplitter(
-                    chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap,
                 )
         client = EmbeddingEngine(
             knowledge_source=doc.content,
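Document sync now derives splitter parameters the same way: space context first, global config defaults (both 100) otherwise. A runnable sketch of the English-language branch; the import path assumes the LangChain version the project pins, and the values are illustrative:

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    space_context = {"embedding": {"chunk_size": 500, "chunk_overlap": 50}}
    chunk_size = 100 if space_context is None else int(space_context["embedding"]["chunk_size"])
    chunk_overlap = 100 if space_context is None else int(space_context["embedding"]["chunk_overlap"])

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    print(len(text_splitter.split_text("some long document text " * 100)))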
@@ -287,3 +323,40 @@ class KnowledgeService:
             doc.result = "document embedding failed" + str(e)
             logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}")
         return knowledge_document_dao.update_knowledge_document(doc)
+
+    def _build_default_context(self):
+        from pilot.scene.chat_knowledge.v1.prompt import (
+            PROMPT_SCENE_DEFINE,
+            _DEFAULT_TEMPLATE,
+        )
+
+        context_template = {
+            "embedding": {
+                "topk": CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
+                "recall_score": 0.0,
+                "recall_type": "TopK",
+                "model": CFG.EMBEDDING_MODEL,
+                "chunk_size": CFG.KNOWLEDGE_CHUNK_SIZE,
+                "chunk_overlap": CFG.KNOWLEDGE_CHUNK_OVERLAP,
+            },
+            "prompt": {
+                "max_token": 2000,
+                "scene": PROMPT_SCENE_DEFINE,
+                "template": _DEFAULT_TEMPLATE,
+            },
+        }
+        context_template_string = json.dumps(context_template, indent=4)
+        return context_template_string
+
+    def get_space_context(self, space_name):
+        request = KnowledgeSpaceRequest()
+        request.name = space_name
+        spaces = self.get_knowledge_space(request)
+        if len(spaces) != 1:
+            raise Exception(
+                f"have not found {space_name} space or found more than one space called {space_name}"
+            )
+        space = spaces[0]
+        if space.context is not None:
+            return json.loads(spaces[0].context)
+        return None
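Roughly what _build_default_context() returns (as a JSON string) with the stock config values: text2vec model, chunk size and overlap of 100, topk 5, max_token 2000. The scene and template fields hold the prompt strings from the chat_knowledge prompt module and are shown here only as placeholders:

    {
        "embedding": {
            "topk": 5,
            "recall_score": 0.0,
            "recall_type": "TopK",
            "model": "text2vec",
            "chunk_size": 100,
            "chunk_overlap": 100
        },
        "prompt": {
            "max_token": 2000,
            "scene": "<PROMPT_SCENE_DEFINE>",
            "template": "<_DEFAULT_TEMPLATE>"
        }
    }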
@@ -1,6 +1,6 @@
 from datetime import datetime

-from sqlalchemy import Column, Integer, String, DateTime, create_engine
+from sqlalchemy import Column, Integer, Text, String, DateTime, create_engine
 from sqlalchemy.ext.declarative import declarative_base

 from pilot.configs.config import Config
@@ -19,11 +19,12 @@ class KnowledgeSpaceEntity(Base):
     vector_type = Column(String(100))
     desc = Column(String(100))
     owner = Column(String(100))
+    context = Column(Text)
     gmt_created = Column(DateTime)
     gmt_modified = Column(DateTime)

     def __repr__(self):
-        return f"KnowledgeSpaceEntity(id={self.id}, name='{self.name}', vector_type='{self.vector_type}', desc='{self.desc}', owner='{self.owner}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
+        return f"KnowledgeSpaceEntity(id={self.id}, name='{self.name}', vector_type='{self.vector_type}', desc='{self.desc}', owner='{self.owner}' context='{self.context}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"


 class KnowledgeSpaceDao:
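One operational note, as an assumption rather than something the commit handles: SQLAlchemy's create_all() only creates missing tables, so an existing installation would likely need the new column added by hand, along these lines:

    from sqlalchemy import text

    # `engine` is assumed to be the application's existing SQLAlchemy engine.
    with engine.connect() as conn:
        conn.execute(text("ALTER TABLE knowledge_space ADD COLUMN context TEXT"))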
@@ -88,14 +89,12 @@ class KnowledgeSpaceDao:
         session.close()
         return result

-    def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity):
-        cursor = self.conn.cursor()
-        query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s"
-        cursor.execute(
-            query, (space.name, space.vector_type, space.desc, space.owner, space_id)
-        )
-        self.conn.commit()
-        cursor.close()
+    def update_knowledge_space(self, space: KnowledgeSpaceEntity):
+        session = self.Session()
+        session.merge(space)
+        session.commit()
+        session.close()
+        return True

     def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
         session = self.Session()
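update_knowledge_space now takes the entity itself and relies on an ORM session merge instead of raw SQL, which is what lets argument_save persist the new context field. Illustrative usage, assuming a KnowledgeSpaceDao instance named dao and an existing space:

    # `dao` and the space name are assumptions for the sake of the example.
    space = dao.get_knowledge_space(KnowledgeSpaceEntity(name="my_space"))[0]
    space.context = '{"embedding": {"topk": 10}, "prompt": {"max_token": 3000}}'
    dao.update_knowledge_space(space)  # merge + commit, returns True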
@@ -1,4 +1,5 @@
 import os
+from typing import Any

 from chromadb.config import Settings
 from langchain.vectorstores import Chroma
@@ -20,13 +21,15 @@ class ChromaStore(VectorStoreBase):
             persist_directory=self.persist_dir,
             anonymized_telemetry=False,
         )
+        collection_metadata = {"hnsw:space": "cosine"}
         self.vector_store_client = Chroma(
             persist_directory=self.persist_dir,
             embedding_function=self.embeddings,
             client_settings=chroma_settings,
+            collection_metadata=collection_metadata
         )

-    def similar_search(self, text, topk) -> None:
+    def similar_search(self, text, topk, **kwargs: Any) -> None:
         logger.info("ChromaStore similar search")
         return self.vector_store_client.similarity_search(text, topk)
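The collection_metadata entry switches the Chroma collection's HNSW index from its default L2 distance to cosine similarity, and similar_search now tolerates extra keyword arguments. A minimal sketch of constructing the wrapper the same way (the path is illustrative, and a real store needs an embedding function):

    from langchain.vectorstores import Chroma

    store = Chroma(
        persist_directory="/tmp/vector_store_demo",   # illustrative path
        embedding_function=None,                      # supply a real embedding function in practice
        collection_metadata={"hnsw:space": "cosine"},
    )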
@@ -1,9 +1,8 @@
 from pilot.vector_store.chroma_store import ChromaStore

 from pilot.vector_store.milvus_store import MilvusStore
-from pilot.vector_store.weaviate_store import WeaviateStore

-connector = {"Chroma": ChromaStore, "Milvus": MilvusStore, "Weaviate": WeaviateStore}
+connector = {"Chroma": ChromaStore, "Milvus": MilvusStore}


 class VectorStoreConnector: