diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 4d695da7d..8c03769b4 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -163,7 +163,9 @@ class Config(metaclass=Singleton):
         ### EMBEDDING Configuration
         self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
         self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
+        self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 100))
         self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
+        self.KNOWLEDGE_SEARCH_MAX_TOKEN = int(os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000))
 
         ### SUMMARY_CONFIG Configuration
         self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")
diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
index e075634de..ef6b281c7 100644
--- a/pilot/scene/chat_knowledge/v1/chat.py
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -20,6 +20,7 @@ from pilot.configs.model_config import (
 from pilot.scene.chat_knowledge.v1.prompt import prompt
 from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+from pilot.server.knowledge.service import KnowledgeService
 
 CFG = Config()
 
@@ -36,6 +37,18 @@ class ChatKnowledge(BaseChat):
             chat_session_id=chat_session_id,
             current_user_input=user_input,
         )
+        self.space_context = self.get_space_context(knowledge_space)
+        self.top_k = (
+            CFG.KNOWLEDGE_SEARCH_TOP_SIZE
+            if self.space_context is None
+            else int(self.space_context["embedding"]["topk"])
+        )
+        # self.recall_score = CFG.KNOWLEDGE_SEARCH_TOP_SIZE if self.space_context is None else self.space_context["embedding"]["recall_score"]
+        self.max_token = (
+            CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
+            if self.space_context is None
+            else int(self.space_context["prompt"]["max_token"])
+        )
         vector_store_config = {
             "vector_store_name": knowledge_space,
             "vector_store_type": CFG.VECTOR_STORE_TYPE,
@@ -48,11 +61,14 @@ class ChatKnowledge(BaseChat):
 
     def generate_input_values(self):
         try:
+            if self.space_context:
+                prompt.template_define = self.space_context["prompt"]["scene"]
+                prompt.template = self.space_context["prompt"]["template"]
             docs = self.knowledge_embedding_client.similar_search(
-                self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
+                self.current_user_input, self.top_k
             )
             context = [d.page_content for d in docs]
-            context = context[:2000]
+            context = context[: self.max_token]
             input_values = {"context": context, "question": self.current_user_input}
         except NoIndexException:
             raise ValueError(
@@ -63,3 +79,7 @@ class ChatKnowledge(BaseChat):
     @property
     def chat_type(self) -> str:
         return ChatScene.ChatKnowledge.value()
+
+    def get_space_context(self, space_name):
+        service = KnowledgeService()
+        return service.get_space_context(space_name)
diff --git a/pilot/server/knowledge/api.py b/pilot/server/knowledge/api.py
index 737b779bc..b6aef1ab3 100644
--- a/pilot/server/knowledge/api.py
+++ b/pilot/server/knowledge/api.py
@@ -20,6 +20,7 @@ from pilot.server.knowledge.request.request import (
     DocumentSyncRequest,
     ChunkQueryRequest,
     DocumentQueryRequest,
+    SpaceArgumentRequest,
 )
 
 from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
@@ -54,13 +55,33 @@ def space_list(request: KnowledgeSpaceRequest):
 
 @router.post("/knowledge/space/delete")
 def space_delete(request: KnowledgeSpaceRequest):
-    print(f"/space/list params:")
+    print(f"/space/delete params:")
     try:
         return Result.succ(knowledge_space_service.delete_space(request.name))
     except Exception as e:
         return Result.faild(code="E000X", msg=f"space list error {e}")
 
 
+@router.post("/knowledge/{space_name}/arguments") +def arguments(space_name: str): + print(f"/knowledge/space/arguments params:") + try: + return Result.succ(knowledge_space_service.arguments(space_name)) + except Exception as e: + return Result.faild(code="E000X", msg=f"space list error {e}") + + +@router.post("/knowledge/{space_name}/argument/save") +def arguments_save(space_name: str, argument_request: SpaceArgumentRequest): + print(f"/knowledge/space/argument/save params:") + try: + return Result.succ( + knowledge_space_service.argument_save(space_name, argument_request) + ) + except Exception as e: + return Result.faild(code="E000X", msg=f"space list error {e}") + + @router.post("/knowledge/{space_name}/document/add") def document_add(space_name: str, request: KnowledgeDocumentRequest): print(f"/document/add params: {space_name}, {request}") diff --git a/pilot/server/knowledge/request/request.py b/pilot/server/knowledge/request/request.py index d393ca9b7..f0c47abeb 100644 --- a/pilot/server/knowledge/request/request.py +++ b/pilot/server/knowledge/request/request.py @@ -83,3 +83,9 @@ class KnowledgeQueryResponse: score: float = 0.0 """text: raw text info""" text: str + + +class SpaceArgumentRequest(BaseModel): + """argument: argument""" + + argument: str diff --git a/pilot/server/knowledge/request/response.py b/pilot/server/knowledge/request/response.py index d302eb392..fb7aa55e9 100644 --- a/pilot/server/knowledge/request/response.py +++ b/pilot/server/knowledge/request/response.py @@ -32,6 +32,8 @@ class SpaceQueryResponse(BaseModel): vector_type: str = None """desc: description""" desc: str = None + """context: context""" + context: str = None """owner: owner""" owner: str = None gmt_created: str = None diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py index 3f863e0c4..60745a118 100644 --- a/pilot/server/knowledge/service.py +++ b/pilot/server/knowledge/service.py @@ -1,3 +1,4 @@ +import json import threading from datetime import datetime @@ -25,6 +26,7 @@ from pilot.server.knowledge.request.request import ( KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest, + SpaceArgumentRequest, ) from enum import Enum @@ -102,12 +104,34 @@ class KnowledgeService: res.gmt_created = space.gmt_created res.gmt_modified = space.gmt_modified res.owner = space.owner + res.context = space.context query = KnowledgeDocumentEntity(space=space.name) doc_count = knowledge_document_dao.get_knowledge_documents_count(query) res.docs = doc_count responses.append(res) return responses + def arguments(self, space_name): + query = KnowledgeSpaceEntity(name=space_name) + spaces = knowledge_space_dao.get_knowledge_space(query) + if len(spaces) != 1: + raise Exception(f"there are no or more than one space called {space_name}") + space = spaces[0] + if space.context is None: + context = self._build_default_context() + else: + context = space.context + return json.loads(context) + + def argument_save(self, space_name, argument_request: SpaceArgumentRequest): + query = KnowledgeSpaceEntity(name=space_name) + spaces = knowledge_space_dao.get_knowledge_space(query) + if len(spaces) != 1: + raise Exception(f"there are no or more than one space called {space_name}") + space = spaces[0] + space.context = argument_request.argument + return knowledge_space_dao.update_knowledge_space(space) + """get knowledge get_knowledge_documents""" def get_knowledge_documents(self, space, request: DocumentQueryRequest): @@ -142,22 +166,34 @@ class KnowledgeService: f" doc:{doc.doc_name} 
status is {doc.status}, can not sync" ) + space_context = self.get_space_context(space_name) + chunk_size = ( + CFG.KNOWLEDGE_CHUNK_SIZE + if space_context is None + else int(space_context["embedding"]["chunk_size"]) + ) + chunk_overlap = ( + CFG.KNOWLEDGE_CHUNK_OVERLAP + if space_context is None + else int(space_context["embedding"]["chunk_overlap"]) + ) if CFG.LANGUAGE == "en": text_splitter = RecursiveCharacterTextSplitter( - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - chunk_overlap=20, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, length_function=len, ) else: try: text_splitter = SpacyTextSplitter( pipeline="zh_core_web_sm", - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, - chunk_overlap=100, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap ) except Exception: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50 + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, ) client = EmbeddingEngine( knowledge_source=doc.content, @@ -287,3 +323,40 @@ class KnowledgeService: doc.result = "document embedding failed" + str(e) logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}") return knowledge_document_dao.update_knowledge_document(doc) + + def _build_default_context(self): + from pilot.scene.chat_knowledge.v1.prompt import ( + PROMPT_SCENE_DEFINE, + _DEFAULT_TEMPLATE, + ) + + context_template = { + "embedding": { + "topk": CFG.KNOWLEDGE_SEARCH_TOP_SIZE, + "recall_score": 0.0, + "recall_type": "TopK", + "model": CFG.EMBEDDING_MODEL, + "chunk_size": CFG.KNOWLEDGE_CHUNK_SIZE, + "chunk_overlap": CFG.KNOWLEDGE_CHUNK_OVERLAP, + }, + "prompt": { + "max_token": 2000, + "scene": PROMPT_SCENE_DEFINE, + "template": _DEFAULT_TEMPLATE, + }, + } + context_template_string = json.dumps(context_template, indent=4) + return context_template_string + + def get_space_context(self, space_name): + request = KnowledgeSpaceRequest() + request.name = space_name + spaces = self.get_knowledge_space(request) + if len(spaces) != 1: + raise Exception( + f"have not found {space_name} space or found more than one space called {space_name}" + ) + space = spaces[0] + if space.context is not None: + return json.loads(spaces[0].context) + return None diff --git a/pilot/server/knowledge/space_db.py b/pilot/server/knowledge/space_db.py index 911683fdd..57a909e06 100644 --- a/pilot/server/knowledge/space_db.py +++ b/pilot/server/knowledge/space_db.py @@ -1,6 +1,6 @@ from datetime import datetime -from sqlalchemy import Column, Integer, String, DateTime, create_engine +from sqlalchemy import Column, Integer, Text, String, DateTime, create_engine from sqlalchemy.ext.declarative import declarative_base from pilot.configs.config import Config @@ -19,11 +19,12 @@ class KnowledgeSpaceEntity(Base): vector_type = Column(String(100)) desc = Column(String(100)) owner = Column(String(100)) + context = Column(Text) gmt_created = Column(DateTime) gmt_modified = Column(DateTime) def __repr__(self): - return f"KnowledgeSpaceEntity(id={self.id}, name='{self.name}', vector_type='{self.vector_type}', desc='{self.desc}', owner='{self.owner}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + return f"KnowledgeSpaceEntity(id={self.id}, name='{self.name}', vector_type='{self.vector_type}', desc='{self.desc}', owner='{self.owner}' context='{self.context}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" class KnowledgeSpaceDao: @@ -88,14 +89,12 @@ class KnowledgeSpaceDao: session.close() return result - def update_knowledge_space(self, 
space_id: int, space: KnowledgeSpaceEntity): - cursor = self.conn.cursor() - query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s" - cursor.execute( - query, (space.name, space.vector_type, space.desc, space.owner, space_id) - ) - self.conn.commit() - cursor.close() + def update_knowledge_space(self, space: KnowledgeSpaceEntity): + session = self.Session() + session.merge(space) + session.commit() + session.close() + return True def delete_knowledge_space(self, space: KnowledgeSpaceEntity): session = self.Session() diff --git a/pilot/vector_store/chroma_store.py b/pilot/vector_store/chroma_store.py index e47a11761..451d9952c 100644 --- a/pilot/vector_store/chroma_store.py +++ b/pilot/vector_store/chroma_store.py @@ -1,4 +1,5 @@ import os +from typing import Any from chromadb.config import Settings from langchain.vectorstores import Chroma @@ -20,13 +21,15 @@ class ChromaStore(VectorStoreBase): persist_directory=self.persist_dir, anonymized_telemetry=False, ) + collection_metadata = {"hnsw:space": "cosine"} self.vector_store_client = Chroma( persist_directory=self.persist_dir, embedding_function=self.embeddings, client_settings=chroma_settings, + collection_metadata=collection_metadata ) - def similar_search(self, text, topk) -> None: + def similar_search(self, text, topk, **kwargs: Any) -> None: logger.info("ChromaStore similar search") return self.vector_store_client.similarity_search(text, topk) diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py index ca56986c8..a1fce8360 100644 --- a/pilot/vector_store/connector.py +++ b/pilot/vector_store/connector.py @@ -1,9 +1,8 @@ from pilot.vector_store.chroma_store import ChromaStore from pilot.vector_store.milvus_store import MilvusStore -from pilot.vector_store.weaviate_store import WeaviateStore -connector = {"Chroma": ChromaStore, "Milvus": MilvusStore, "Weaviate": WeaviateStore} +connector = {"Chroma": ChromaStore, "Milvus": MilvusStore} class VectorStoreConnector:
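Below is a minimal client-side sketch (not part of the patch above) of how the two new space-argument endpoints could be exercised once this change is deployed. The base URL, the absence of an extra route prefix on the knowledge router, and the "data" envelope field returned by Result.succ are assumptions about the deployment, not facts established by this diff; adjust them to your setup. The saved context is the same JSON produced by KnowledgeService._build_default_context: its "embedding" section drives chunking during document sync and retrieval depth in ChatKnowledge, while its "prompt" section overrides the scene define, template, and max token budget.

import json

import requests

# Assumed host/port and knowledge space name -- both are hypothetical.
BASE_URL = "http://localhost:5000"
SPACE_NAME = "my_space"

# Read the current arguments; when space.context is NULL the service falls back
# to the default context built by KnowledgeService._build_default_context().
resp = requests.post(f"{BASE_URL}/knowledge/{SPACE_NAME}/arguments")
context = resp.json()["data"]  # assumes Result.succ wraps the payload in a "data" field

# Tweak retrieval/chunking parameters and persist them. argument_save expects the
# whole context serialized as a JSON string in the `argument` field (SpaceArgumentRequest).
context["embedding"]["topk"] = 10
context["embedding"]["chunk_size"] = 500
requests.post(
    f"{BASE_URL}/knowledge/{SPACE_NAME}/argument/save",
    json={"argument": json.dumps(context)},
)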