style:fmt

aries_ckt 2023-10-19 20:16:38 +08:00
parent 01eae42554
commit d11ec46ee5
9 changed files with 131 additions and 31 deletions

View File

@@ -21,7 +21,13 @@ class CSVEmbedding(SourceEmbedding):
source_reader: Optional = None,
text_splitter: Optional[TextSplitter] = None,
):
"""Initialize with csv path."""
"""Initialize with csv path.
Args:
- file_path: data source path
- vector_store_config: vector store config params.
- source_reader: Optional[BaseLoader]
- text_splitter: Optional[TextSplitter]
"""
super().__init__(
file_path, vector_store_config, source_reader=None, text_splitter=None
)
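A minimal usage sketch for the constructor documented above; the same file_path / vector_store_config / source_reader / text_splitter pattern applies to the PDF, PPT, String, URL and Word embeddings in the files below. The import path and every vector_store_config key except "vector_store_type" (the only key this commit shows being read) are assumptions, not taken from this commit.

from pilot.embedding_engine.csv_embedding import CSVEmbedding  # assumed module path

vector_store_config = {
    "vector_store_name": "my_space",   # assumed key
    "vector_store_type": "Chroma",     # key read by EmbeddingEngine below
}

embedding = CSVEmbedding(
    file_path="./data/prices.csv",
    vector_store_config=vector_store_config,
    source_reader=None,   # None: fall back to the default CSV loader
    text_splitter=None,   # None: fall back to the default text splitter
)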

View File

@@ -28,7 +28,16 @@ class EmbeddingEngine:
text_splitter: Optional[TextSplitter] = None,
embedding_factory: EmbeddingFactory = None,
):
"""Initialize with knowledge embedding client, model_name, vector_store_config, knowledge_type, knowledge_source"""
"""Initialize with knowledge embedding client, model_name, vector_store_config, knowledge_type, knowledge_source
Args:
- model_name: embedding model name
- vector_store_config: vector store config params (Dict)
- knowledge_type: Optional[KnowledgeType]
- knowledge_source: Optional[str]
- source_reader: Optional[BaseLoader]
- text_splitter: Optional[TextSplitter]
- embedding_factory: EmbeddingFactory
"""
self.knowledge_source = knowledge_source
self.model_name = model_name
self.vector_store_config = vector_store_config
@@ -65,6 +74,11 @@ class EmbeddingEngine:
)
def similar_search(self, text, topk):
"""vector db similar search
Args:
- text: query text
- topk: number of top results to return
"""
vector_client = VectorStoreConnector(
self.vector_store_config["vector_store_type"], self.vector_store_config
)
@@ -75,12 +89,17 @@ class EmbeddingEngine:
return ans
def vector_exist(self):
"""vector db is exist"""
vector_client = VectorStoreConnector(
self.vector_store_config["vector_store_type"], self.vector_store_config
)
return vector_client.vector_name_exists()
def delete_by_ids(self, ids):
"""delete vector db by ids
Args:
- ids: vector ids
"""
vector_client = VectorStoreConnector(
self.vector_store_config["vector_store_type"], self.vector_store_config
)
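A hedged sketch of how the three methods documented above might be called. The import path is taken from the sync_knowledge_document hunk later in this commit; the model name, the config keys other than "vector_store_type", the omitted embedding_factory, and the ids format are assumptions.

from pilot.embedding_engine.embedding_engine import EmbeddingEngine

engine = EmbeddingEngine(
    model_name="text2vec",                   # assumed embedding model name
    vector_store_config={
        "vector_store_name": "my_space",     # assumed key
        "vector_store_type": "Chroma",       # key read by similar_search above
    },
    knowledge_source="./docs/manual.md",     # assumed source path
    # embedding_factory is left at its default here; a real deployment may need to pass one
)

if engine.vector_exist():                    # only query an existing vector db
    answers = engine.similar_search("how to deploy", topk=5)
    engine.delete_by_ids("1,2,3")            # ids format is an assumption

Note that each method builds its own VectorStoreConnector from vector_store_config, so the same dict has to be passed for search, existence checks and deletes to reach the same store.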

View File

@@ -23,7 +23,13 @@ class PDFEmbedding(SourceEmbedding):
source_reader: Optional = None,
text_splitter: Optional[TextSplitter] = None,
):
"""Initialize pdf word path."""
"""Initialize pdf word path.
Args:
- file_path: data source path
- vector_store_config: vector store config params.
- source_reader: Optional[BaseLoader]
- text_splitter: Optional[TextSplitter]
"""
super().__init__(
file_path, vector_store_config, source_reader=None, text_splitter=None
)

View File

@@ -23,7 +23,13 @@ class PPTEmbedding(SourceEmbedding):
source_reader: Optional = None,
text_splitter: Optional[TextSplitter] = None,
):
"""Initialize ppt word path."""
"""Initialize ppt word path.
Args:
- file_path: data source path
- vector_store_config: vector store config params.
- source_reader: Optional[BaseLoader]
- text_splitter: Optional[TextSplitter]
"""
super().__init__(
file_path, vector_store_config, source_reader=None, text_splitter=None
)

View File

@@ -33,7 +33,7 @@ class SourceEmbedding(ABC):
Args:
- file_path: data source path
- vector_store_config: vector store config params.
- source_reader: Optional[]
- source_reader: Optional[BaseLoader]
- text_splitter: Optional[TextSplitter]
- embedding_args: Optional
"""
@@ -52,8 +52,8 @@ class SourceEmbedding(ABC):
@register
def data_process(self, text):
"""pre process data.
Args:
- text: raw text
Args:
- text: raw text
"""
@register
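The @register decorator shown in this hunk marks the steps a SourceEmbedding subclass contributes to the embedding pipeline. Below is a minimal sketch of that pattern; the import path, the self.file_path attribute, and the read()/data_process() contract built on langchain Document objects are assumptions inferred from context, not spelled out in this commit.

from langchain.schema import Document

from pilot.embedding_engine.source_embedding import SourceEmbedding, register  # assumed path

class RawTextFileEmbedding(SourceEmbedding):   # hypothetical subclass, for illustration only
    @register
    def read(self):
        """Load the source file into Document objects (assumed contract)."""
        with open(self.file_path, encoding="utf-8") as f:
            return [Document(page_content=f.read())]

    @register
    def data_process(self, documents):
        """Pre-process data: strip surrounding whitespace before embedding."""
        for d in documents:
            d.page_content = d.page_content.strip()
        return documents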

View File

@@ -20,7 +20,13 @@ class StringEmbedding(SourceEmbedding):
source_reader: Optional = None,
text_splitter: Optional[TextSplitter] = None,
):
"""Initialize raw text word path."""
"""Initialize raw text word path.
Args:
- file_path: data source path
- vector_store_config: vector store config params.
- source_reader: Optional[BaseLoader]
- text_splitter: Optional[TextSplitter]
"""
super().__init__(
file_path=file_path,
vector_store_config=vector_store_config,

View File

@@ -22,7 +22,13 @@ class URLEmbedding(SourceEmbedding):
source_reader: Optional = None,
text_splitter: Optional[TextSplitter] = None,
):
"""Initialize url word path."""
"""Initialize url word path.
Args:
- file_path: data source path
- vector_store_config: vector store config params.
- source_reader: Optional[BaseLoader]
- text_splitter: Optional[TextSplitter]
"""
super().__init__(
file_path, vector_store_config, source_reader=None, text_splitter=None
)

View File

@@ -23,7 +23,13 @@ class WordEmbedding(SourceEmbedding):
source_reader: Optional = None,
text_splitter: Optional[TextSplitter] = None,
):
"""Initialize with word path."""
"""Initialize with word path.
Args:
- file_path: data source path
- vector_store_config: vector store config params.
- source_reader: Optional[BaseLoader]
- text_splitter: Optional[TextSplitter]
"""
super().__init__(
file_path, vector_store_config, source_reader=None, text_splitter=None
)

View File

@@ -57,12 +57,21 @@ class SyncStatus(Enum):
# @singleton
class KnowledgeService:
"""KnowledgeService
Knowledge Management Service:
- knowledge_space management
- knowledge_document management
- embedding management
"""
def __init__(self):
pass
"""create knowledge space"""
def create_knowledge_space(self, request: KnowledgeSpaceRequest):
"""create knowledge space
Args:
- request: KnowledgeSpaceRequest
"""
query = KnowledgeSpaceEntity(
name=request.name,
)
@@ -72,9 +81,11 @@ class KnowledgeService:
knowledge_space_dao.create_knowledge_space(request)
return True
"""create knowledge document"""
def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
"""create knowledge document
Args:
- request: KnowledgeDocumentRequest
"""
query = KnowledgeDocumentEntity(doc_name=request.doc_name, space=space)
documents = knowledge_document_dao.get_knowledge_documents(query)
if len(documents) > 0:
@@ -91,9 +102,11 @@ class KnowledgeService:
)
return knowledge_document_dao.create_knowledge_document(document)
"""get knowledge space"""
def get_knowledge_space(self, request: KnowledgeSpaceRequest):
"""get knowledge space
Args:
- request: KnowledgeSpaceRequest
"""
query = KnowledgeSpaceEntity(
name=request.name, vector_type=request.vector_type, owner=request.owner
)
@@ -116,6 +129,10 @@ class KnowledgeService:
return responses
def arguments(self, space_name):
"""show knowledge space arguments
Args:
- space_name: Knowledge Space Name
"""
query = KnowledgeSpaceEntity(name=space_name)
spaces = knowledge_space_dao.get_knowledge_space(query)
if len(spaces) != 1:
@@ -128,6 +145,11 @@ class KnowledgeService:
return json.loads(context)
def argument_save(self, space_name, argument_request: SpaceArgumentRequest):
"""save argument
Args:
- space_name: Knowledge Space Name
- argument_request: SpaceArgumentRequest
"""
query = KnowledgeSpaceEntity(name=space_name)
spaces = knowledge_space_dao.get_knowledge_space(query)
if len(spaces) != 1:
@@ -136,9 +158,12 @@ class KnowledgeService:
space.context = argument_request.argument
return knowledge_space_dao.update_knowledge_space(space)
"""get knowledge get_knowledge_documents"""
def get_knowledge_documents(self, space, request: DocumentQueryRequest):
"""get knowledge documents
Args:
- space: Knowledge Space Name
- request: DocumentQueryRequest
"""
query = KnowledgeDocumentEntity(
doc_name=request.doc_name,
doc_type=request.doc_type,
@@ -153,9 +178,12 @@ class KnowledgeService:
res.page = request.page
return res
"""sync knowledge document chunk into vector store"""
def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest):
"""sync knowledge document chunk into vector store
Args:
- space_name: Knowledge Space Name
- sync_request: DocumentSyncRequest
"""
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.embedding_engine.pre_text_splitter import PreTextSplitter
@@ -249,11 +277,6 @@ class KnowledgeService:
doc.chunk_size = len(chunk_docs)
doc.gmt_modified = datetime.now()
knowledge_document_dao.update_knowledge_document(doc)
# async doc embeddings
# thread = threading.Thread(
# target=self.async_doc_embedding, args=(client, chunk_docs, doc)
# )
# thread.start()
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
@@ -277,16 +300,21 @@ class KnowledgeService:
return True
"""update knowledge space"""
def update_knowledge_space(
self, space_id: int, space_request: KnowledgeSpaceRequest
):
"""update knowledge space
Args:
- space_id: space id
- space_request: KnowledgeSpaceRequest
"""
knowledge_space_dao.update_knowledge_space(space_id, space_request)
"""delete knowledge space"""
def delete_space(self, space_name: str):
"""delete knowledge space
Args:
- space_name: knowledge space name
"""
query = KnowledgeSpaceEntity(name=space_name)
spaces = knowledge_space_dao.get_knowledge_space(query)
if len(spaces) == 0:
@@ -312,6 +340,11 @@ class KnowledgeService:
return knowledge_space_dao.delete_knowledge_space(space)
def delete_document(self, space_name: str, doc_name: str):
"""delete document
Args:
- space_name: knowledge space name
- doc_name: document name
"""
document_query = KnowledgeDocumentEntity(doc_name=doc_name, space=space_name)
documents = knowledge_document_dao.get_documents(document_query)
if len(documents) != 1:
@@ -332,9 +365,11 @@ class KnowledgeService:
# delete document
return knowledge_document_dao.delete(document_query)
"""get document chunks"""
def get_document_chunks(self, request: ChunkQueryRequest):
"""get document chunks
Args:
- request: ChunkQueryRequest
"""
query = DocumentChunkEntity(
id=request.id,
document_id=request.document_id,
@@ -350,6 +385,12 @@ class KnowledgeService:
return res
def async_doc_embedding(self, client, chunk_docs, doc):
"""async document embedding into vector db
Args:
- client: EmbeddingEngine Client
- chunk_docs: List[Document]
- doc: knowledge document entity to update
"""
logger.info(
f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
)
@@ -391,6 +432,10 @@ class KnowledgeService:
return context_template_string
def get_space_context(self, space_name):
"""get space contect
Args:
- space_name: space name
"""
request = KnowledgeSpaceRequest()
request.name = space_name
spaces = self.get_knowledge_space(request)
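To tie the KnowledgeService methods above together, here is a hedged end-to-end sketch. The import paths and any request fields not visible in this diff (everything except name, vector_type, owner and doc_name) are assumptions.

from pilot.server.knowledge.service import KnowledgeService       # assumed module path
from pilot.server.knowledge.request.request import (              # assumed module path
    KnowledgeSpaceRequest,
    KnowledgeDocumentRequest,
    DocumentSyncRequest,
)

service = KnowledgeService()

# 1. create a knowledge space (name/vector_type/owner appear in get_knowledge_space above)
service.create_knowledge_space(
    KnowledgeSpaceRequest(name="my_space", vector_type="Chroma", owner="admin")
)

# 2. register a document inside the space (doc_type and content are assumed fields)
doc_id = service.create_knowledge_document(
    space="my_space",
    request=KnowledgeDocumentRequest(
        doc_name="manual.md", doc_type="DOCUMENT", content="./manual.md"
    ),
)

# 3. split, embed and sync the document chunks into the vector store
service.sync_knowledge_document(
    space_name="my_space",
    sync_request=DocumentSyncRequest(doc_ids=[doc_id]),            # doc_ids field is an assumption
)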