feat: add knowledge space argument

This commit is contained in:
aries_ckt 2023-08-07 16:36:41 +08:00
parent cc57ed22ca
commit aa3c3205a4
9 changed files with 146 additions and 21 deletions

View File

@ -163,7 +163,9 @@ class Config(metaclass=Singleton):
### EMBEDDING Configuration
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 100))
self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
self.KNOWLEDGE_SEARCH_MAX_TOKEN = int(os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000))
### SUMMARY_CONFIG Configuration
self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")

View File

@ -20,6 +20,7 @@ from pilot.configs.model_config import (
from pilot.scene.chat_knowledge.v1.prompt import prompt
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.server.knowledge.service import KnowledgeService
CFG = Config()
@ -36,6 +37,18 @@ class ChatKnowledge(BaseChat):
chat_session_id=chat_session_id,
current_user_input=user_input,
)
self.space_context = self.get_space_context(knowledge_space)
self.top_k = (
CFG.KNOWLEDGE_SEARCH_TOP_SIZE
if self.space_context is None
else int(self.space_context["embedding"]["topk"])
)
# self.recall_score = CFG.KNOWLEDGE_SEARCH_TOP_SIZE if self.space_context is None else self.space_context["embedding"]["recall_score"]
self.max_token = (
CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
if self.space_context is None
else int(self.space_context["prompt"]["max_token"])
)
vector_store_config = {
"vector_store_name": knowledge_space,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
@ -48,11 +61,14 @@ class ChatKnowledge(BaseChat):
def generate_input_values(self):
try:
if self.space_context:
prompt.template_define = self.space_context["prompt"]["scene"]
prompt.template = self.space_context["prompt"]["template"]
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
self.current_user_input, self.top_k
)
context = [d.page_content for d in docs]
context = context[:2000]
context = context[: self.max_token]
input_values = {"context": context, "question": self.current_user_input}
except NoIndexException:
raise ValueError(
@ -63,3 +79,7 @@ class ChatKnowledge(BaseChat):
@property
def chat_type(self) -> str:
    """Scene identifier string for the knowledge-chat scene."""
    scene = ChatScene.ChatKnowledge
    return scene.value()
def get_space_context(self, space_name):
    """Fetch the stored context configuration for *space_name* from the knowledge service."""
    return KnowledgeService().get_space_context(space_name)

View File

@ -20,6 +20,7 @@ from pilot.server.knowledge.request.request import (
DocumentSyncRequest,
ChunkQueryRequest,
DocumentQueryRequest,
SpaceArgumentRequest,
)
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
@ -54,13 +55,33 @@ def space_list(request: KnowledgeSpaceRequest):
@router.post("/knowledge/space/delete")
def space_delete(request: KnowledgeSpaceRequest):
    """Delete the knowledge space named in the request.

    Returns a success Result, or a failure Result (code E000X) on any error.
    """
    print(f"/space/delete params:")
    try:
        return Result.succ(knowledge_space_service.delete_space(request.name))
    except Exception as e:
        # Was "space list error" — copy-paste from the list endpoint.
        return Result.faild(code="E000X", msg=f"space delete error {e}")
@router.post("/knowledge/{space_name}/arguments")
def arguments(space_name: str):
    """Return the argument/context settings of a knowledge space.

    Returns a success Result wrapping the parsed arguments, or a failure
    Result (code E000X) on any error.
    """
    print(f"/knowledge/space/arguments params:")
    try:
        return Result.succ(knowledge_space_service.arguments(space_name))
    except Exception as e:
        # Was "space list error" — copy-paste from the list endpoint.
        return Result.faild(code="E000X", msg=f"space arguments error {e}")
@router.post("/knowledge/{space_name}/argument/save")
def arguments_save(space_name: str, argument_request: SpaceArgumentRequest):
    """Persist the submitted argument string as the space's context.

    Returns a success Result, or a failure Result (code E000X) on any error.
    """
    print(f"/knowledge/space/argument/save params:")
    try:
        return Result.succ(
            knowledge_space_service.argument_save(space_name, argument_request)
        )
    except Exception as e:
        # Was "space list error" — copy-paste from the list endpoint.
        return Result.faild(code="E000X", msg=f"space argument save error {e}")
@router.post("/knowledge/{space_name}/document/add")
def document_add(space_name: str, request: KnowledgeDocumentRequest):
print(f"/document/add params: {space_name}, {request}")

View File

@ -83,3 +83,9 @@ class KnowledgeQueryResponse:
score: float = 0.0
"""text: raw text info"""
text: str
class SpaceArgumentRequest(BaseModel):
    # Request body for the /knowledge/{space_name}/argument/save endpoint.
    # `argument` is stored verbatim as the space's context; it is presumably
    # expected to be a JSON string, since readers parse context with
    # json.loads — TODO confirm against the service layer.
    """argument: argument"""
    argument: str

View File

@ -32,6 +32,8 @@ class SpaceQueryResponse(BaseModel):
vector_type: str = None
"""desc: description"""
desc: str = None
"""context: context"""
context: str = None
"""owner: owner"""
owner: str = None
gmt_created: str = None

View File

@ -1,3 +1,4 @@
import json
import threading
from datetime import datetime
@ -25,6 +26,7 @@ from pilot.server.knowledge.request.request import (
KnowledgeDocumentRequest,
DocumentQueryRequest,
ChunkQueryRequest,
SpaceArgumentRequest,
)
from enum import Enum
@ -102,12 +104,34 @@ class KnowledgeService:
res.gmt_created = space.gmt_created
res.gmt_modified = space.gmt_modified
res.owner = space.owner
res.context = space.context
query = KnowledgeDocumentEntity(space=space.name)
doc_count = knowledge_document_dao.get_knowledge_documents_count(query)
res.docs = doc_count
responses.append(res)
return responses
def arguments(self, space_name):
    """Return the parsed context arguments for the space named *space_name*.

    Falls back to the built-in default context when the space has none stored.
    Raises Exception when zero or multiple spaces match the name.
    """
    matches = knowledge_space_dao.get_knowledge_space(
        KnowledgeSpaceEntity(name=space_name)
    )
    if len(matches) != 1:
        raise Exception(f"there are no or more than one space called {space_name}")
    stored_context = matches[0].context
    if stored_context is None:
        stored_context = self._build_default_context()
    return json.loads(stored_context)
def argument_save(self, space_name, argument_request: SpaceArgumentRequest):
    """Store the request's argument string as the context of the named space.

    Raises Exception when zero or multiple spaces match the name; returns the
    DAO's update result.
    """
    matches = knowledge_space_dao.get_knowledge_space(
        KnowledgeSpaceEntity(name=space_name)
    )
    if len(matches) != 1:
        raise Exception(f"there are no or more than one space called {space_name}")
    target = matches[0]
    target.context = argument_request.argument
    return knowledge_space_dao.update_knowledge_space(target)
"""get knowledge get_knowledge_documents"""
def get_knowledge_documents(self, space, request: DocumentQueryRequest):
@ -142,22 +166,34 @@ class KnowledgeService:
f" doc:{doc.doc_name} status is {doc.status}, can not sync"
)
space_context = self.get_space_context(space_name)
chunk_size = (
CFG.KNOWLEDGE_CHUNK_SIZE
if space_context is None
else int(space_context["embedding"]["chunk_size"])
)
chunk_overlap = (
CFG.KNOWLEDGE_CHUNK_OVERLAP
if space_context is None
else int(space_context["embedding"]["chunk_overlap"])
)
if CFG.LANGUAGE == "en":
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=20,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len,
)
else:
try:
text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=100,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
except Exception:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
client = EmbeddingEngine(
knowledge_source=doc.content,
@ -287,3 +323,40 @@ class KnowledgeService:
doc.result = "document embedding failed" + str(e)
logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}")
return knowledge_document_dao.update_knowledge_document(doc)
def _build_default_context(self):
    """Assemble the default space context (embedding + prompt sections) as a
    pretty-printed JSON string, sourcing defaults from CFG and the v1 prompt."""
    # Imported locally to avoid a module-level import cycle with the prompt module.
    from pilot.scene.chat_knowledge.v1.prompt import (
        PROMPT_SCENE_DEFINE,
        _DEFAULT_TEMPLATE,
    )

    embedding_section = {
        "topk": CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
        "recall_score": 0.0,
        "recall_type": "TopK",
        "model": CFG.EMBEDDING_MODEL,
        "chunk_size": CFG.KNOWLEDGE_CHUNK_SIZE,
        "chunk_overlap": CFG.KNOWLEDGE_CHUNK_OVERLAP,
    }
    prompt_section = {
        "max_token": 2000,
        "scene": PROMPT_SCENE_DEFINE,
        "template": _DEFAULT_TEMPLATE,
    }
    return json.dumps(
        {"embedding": embedding_section, "prompt": prompt_section}, indent=4
    )
def get_space_context(self, space_name):
    """Load and JSON-parse the stored context of *space_name*.

    Returns None when the space has no stored context.
    Raises Exception when zero or multiple spaces match the name.
    """
    request = KnowledgeSpaceRequest()
    request.name = space_name
    matches = self.get_knowledge_space(request)
    if len(matches) != 1:
        raise Exception(
            f"have not found {space_name} space or found more than one space called {space_name}"
        )
    raw_context = matches[0].context
    if raw_context is None:
        return None
    return json.loads(raw_context)

View File

@ -1,6 +1,6 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, create_engine
from sqlalchemy import Column, Integer, Text, String, DateTime, create_engine
from sqlalchemy.ext.declarative import declarative_base
from pilot.configs.config import Config
@ -19,11 +19,12 @@ class KnowledgeSpaceEntity(Base):
vector_type = Column(String(100))
desc = Column(String(100))
owner = Column(String(100))
context = Column(Text)
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)
def __repr__(self):
    """Debug representation listing every persisted column."""
    # Fixed: the string was missing the comma separator between the
    # owner and context fields.
    return (
        f"KnowledgeSpaceEntity(id={self.id}, name='{self.name}', "
        f"vector_type='{self.vector_type}', desc='{self.desc}', "
        f"owner='{self.owner}', context='{self.context}', "
        f"gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
    )
class KnowledgeSpaceDao:
@ -88,14 +89,12 @@ class KnowledgeSpaceDao:
session.close()
return result
def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity):
cursor = self.conn.cursor()
query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s"
cursor.execute(
query, (space.name, space.vector_type, space.desc, space.owner, space_id)
)
self.conn.commit()
cursor.close()
def update_knowledge_space(self, space: KnowledgeSpaceEntity):
    """Merge *space* into the database and commit.

    Returns:
        bool: True when the update committed.
    """
    session = self.Session()
    try:
        session.merge(space)
        session.commit()
    finally:
        # Release the session even when merge/commit raises,
        # so a failed update does not leak a connection.
        session.close()
    return True
def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
session = self.Session()

View File

@ -1,4 +1,5 @@
import os
from typing import Any
from chromadb.config import Settings
from langchain.vectorstores import Chroma
@ -20,13 +21,15 @@ class ChromaStore(VectorStoreBase):
persist_directory=self.persist_dir,
anonymized_telemetry=False,
)
collection_metadata = {"hnsw:space": "cosine"}
self.vector_store_client = Chroma(
persist_directory=self.persist_dir,
embedding_function=self.embeddings,
client_settings=chroma_settings,
collection_metadata=collection_metadata
)
def similar_search(self, text, topk) -> None:
def similar_search(self, text, topk, **kwargs: Any):
    """Return the top-*topk* documents most similar to *text*.

    Delegates to the underlying Chroma client's similarity_search; extra
    keyword arguments are accepted for interface compatibility but unused.
    (The previous ``-> None`` annotation was wrong: the search result is
    returned.)
    """
    logger.info("ChromaStore similar search")
    return self.vector_store_client.similarity_search(text, topk)

View File

@ -1,9 +1,8 @@
from pilot.vector_store.chroma_store import ChromaStore
from pilot.vector_store.milvus_store import MilvusStore
from pilot.vector_store.weaviate_store import WeaviateStore
connector = {"Chroma": ChromaStore, "Milvus": MilvusStore, "Weaviate": WeaviateStore}
# Registry mapping a VECTOR_STORE_TYPE name to its vector-store connector class.
connector = {"Chroma": ChromaStore, "Milvus": MilvusStore}
class VectorStoreConnector: