From 929e7fe96bdd9eb3072d11c30e9bccdb16fc394a Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Wed, 12 Jul 2023 11:07:35 +0800 Subject: [PATCH] refactor:refactor knowledge api 1.delete CFG in embedding_engine api 2.add a text_splitter param in embedding_engine api 3.fmt --- pilot/embedding_engine/csv_embedding.py | 8 +++++++- pilot/embedding_engine/markdown_embedding.py | 10 ++++++++-- pilot/embedding_engine/pdf_embedding.py | 15 ++++++++++++--- pilot/embedding_engine/ppt_embedding.py | 15 ++++++++++++--- pilot/embedding_engine/source_embedding.py | 4 ++-- pilot/embedding_engine/string_embedding.py | 12 +++++++++--- pilot/embedding_engine/url_embedding.py | 16 ++++++++++++---- pilot/embedding_engine/word_embedding.py | 15 ++++++++++++--- pilot/vector_store/milvus_store.py | 2 -- .../unit/embedding_engine/test_url_embedding.py | 16 ++++++++++------ tools/knowledge_init.py | 9 +++++++-- 11 files changed, 91 insertions(+), 31 deletions(-) diff --git a/pilot/embedding_engine/csv_embedding.py b/pilot/embedding_engine/csv_embedding.py index ad2ca4333..9ba28459b 100644 --- a/pilot/embedding_engine/csv_embedding.py +++ b/pilot/embedding_engine/csv_embedding.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional from langchain.document_loaders import CSVLoader from langchain.schema import Document +from langchain.text_splitter import TextSplitter from pilot.embedding_engine import SourceEmbedding, register @@ -9,7 +10,12 @@ from pilot.embedding_engine import SourceEmbedding, register class CSVEmbedding(SourceEmbedding): """csv embedding for read csv document.""" - def __init__(self, file_path, vector_store_config, text_splitter=None): + def __init__( + self, + file_path, + vector_store_config, + text_splitter: Optional[TextSplitter] = None, + ): """Initialize with csv path.""" super().__init__(file_path, vector_store_config, text_splitter=None) self.file_path = file_path diff --git a/pilot/embedding_engine/markdown_embedding.py b/pilot/embedding_engine/markdown_embedding.py index 03969a925..fa2ddc914 100644 --- a/pilot/embedding_engine/markdown_embedding.py +++ b/pilot/embedding_engine/markdown_embedding.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- import os -from typing import List +from typing import List, Optional import markdown from bs4 import BeautifulSoup @@ -10,6 +10,7 @@ from langchain.text_splitter import ( SpacyTextSplitter, CharacterTextSplitter, RecursiveCharacterTextSplitter, + TextSplitter, ) from pilot.embedding_engine import SourceEmbedding, register @@ -19,7 +20,12 @@ from pilot.embedding_engine.EncodeTextLoader import EncodeTextLoader class MarkdownEmbedding(SourceEmbedding): """markdown embedding for read markdown document.""" - def __init__(self, file_path, vector_store_config, text_splitter=None): + def __init__( + self, + file_path, + vector_store_config, + text_splitter: Optional[TextSplitter] = None, + ): """Initialize raw text word path.""" super().__init__(file_path, vector_store_config, text_splitter=None) self.file_path = file_path diff --git a/pilot/embedding_engine/pdf_embedding.py b/pilot/embedding_engine/pdf_embedding.py index 5928ade43..cbe68da1b 100644 --- a/pilot/embedding_engine/pdf_embedding.py +++ b/pilot/embedding_engine/pdf_embedding.py @@ -1,10 +1,14 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import List +from typing import List, Optional from langchain.document_loaders import PyPDFLoader from langchain.schema import Document -from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter +from langchain.text_splitter import ( + SpacyTextSplitter, + RecursiveCharacterTextSplitter, + TextSplitter, +) from pilot.embedding_engine import SourceEmbedding, register @@ -12,7 +16,12 @@ from pilot.embedding_engine import SourceEmbedding, register class PDFEmbedding(SourceEmbedding): """pdf embedding for read pdf document.""" - def __init__(self, file_path, vector_store_config, text_splitter=None): + def __init__( + self, + file_path, + vector_store_config, + text_splitter: Optional[TextSplitter] = None, + ): """Initialize pdf word path.""" super().__init__(file_path, vector_store_config, text_splitter=None) self.file_path = file_path diff --git a/pilot/embedding_engine/ppt_embedding.py b/pilot/embedding_engine/ppt_embedding.py index 7dd6f057f..59de18392 100644 --- a/pilot/embedding_engine/ppt_embedding.py +++ b/pilot/embedding_engine/ppt_embedding.py @@ -1,10 +1,14 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import List +from typing import List, Optional from langchain.document_loaders import UnstructuredPowerPointLoader from langchain.schema import Document -from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter +from langchain.text_splitter import ( + SpacyTextSplitter, + RecursiveCharacterTextSplitter, + TextSplitter, +) from pilot.embedding_engine import SourceEmbedding, register @@ -12,7 +16,12 @@ from pilot.embedding_engine import SourceEmbedding, register class PPTEmbedding(SourceEmbedding): """ppt embedding for read ppt document.""" - def __init__(self, file_path, vector_store_config, text_splitter=None): + def __init__( + self, + file_path, + vector_store_config, + text_splitter: Optional[TextSplitter] = None, + ): """Initialize ppt word path.""" super().__init__(file_path, vector_store_config, text_splitter=None) self.file_path = file_path diff --git a/pilot/embedding_engine/source_embedding.py b/pilot/embedding_engine/source_embedding.py index 6eb5b2265..c1ceabed1 100644 --- a/pilot/embedding_engine/source_embedding.py +++ b/pilot/embedding_engine/source_embedding.py @@ -26,13 +26,13 @@ class SourceEmbedding(ABC): self, file_path, vector_store_config: {}, - text_splitter: TextSplitter = None, + text_splitter: Optional[TextSplitter] = None, embedding_args: Optional[Dict] = None, ): """Initialize with Loader url, model_name, vector_store_config""" self.file_path = file_path self.vector_store_config = vector_store_config - self.text_splitter = text_splitter + self.text_splitter = text_splitter or None self.embedding_args = embedding_args self.embeddings = vector_store_config["embeddings"] diff --git a/pilot/embedding_engine/string_embedding.py b/pilot/embedding_engine/string_embedding.py index f6506c153..64d81899b 100644 --- a/pilot/embedding_engine/string_embedding.py +++ b/pilot/embedding_engine/string_embedding.py @@ -1,6 +1,7 @@ -from typing import List +from typing import List, Optional from langchain.schema import Document +from langchain.text_splitter import TextSplitter from pilot.embedding_engine import SourceEmbedding, register @@ -8,9 +9,14 @@ from pilot.embedding_engine import SourceEmbedding, register class StringEmbedding(SourceEmbedding): """string embedding for read string document.""" - def __init__(self, file_path, vector_store_config, text_splitter=None): + def __init__( + self, + file_path, + vector_store_config, + text_splitter: Optional[TextSplitter] = None, + ): """Initialize raw text word path.""" - super().__init__(file_path, vector_store_config, text_splitter=None) + super().__init__(file_path=file_path, vector_store_config=vector_store_config) self.file_path = file_path self.vector_store_config = vector_store_config self.text_splitter = text_splitter or None diff --git a/pilot/embedding_engine/url_embedding.py b/pilot/embedding_engine/url_embedding.py index 8c79b05c5..e748d2d59 100644 --- a/pilot/embedding_engine/url_embedding.py +++ b/pilot/embedding_engine/url_embedding.py @@ -1,18 +1,26 @@ -from typing import List +from typing import List, Optional from bs4 import BeautifulSoup from langchain.document_loaders import WebBaseLoader from langchain.schema import Document -from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter +from langchain.text_splitter import ( + SpacyTextSplitter, + RecursiveCharacterTextSplitter, + TextSplitter, +) from pilot.embedding_engine import SourceEmbedding, register - class URLEmbedding(SourceEmbedding): """url embedding for read url document.""" - def __init__(self, file_path, vector_store_config, text_splitter=None): + def __init__( + self, + file_path, + vector_store_config, + text_splitter: Optional[TextSplitter] = None, + ): """Initialize url word path.""" super().__init__(file_path, vector_store_config, text_splitter=None) self.file_path = file_path diff --git a/pilot/embedding_engine/word_embedding.py b/pilot/embedding_engine/word_embedding.py index d7995bbf3..98bebec3a 100644 --- a/pilot/embedding_engine/word_embedding.py +++ b/pilot/embedding_engine/word_embedding.py @@ -1,10 +1,14 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import List +from typing import List, Optional from langchain.document_loaders import UnstructuredWordDocumentLoader from langchain.schema import Document -from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter +from langchain.text_splitter import ( + SpacyTextSplitter, + RecursiveCharacterTextSplitter, + TextSplitter, +) from pilot.embedding_engine import SourceEmbedding, register @@ -12,7 +16,12 @@ from pilot.embedding_engine import SourceEmbedding, register class WordEmbedding(SourceEmbedding): """word embedding for read word document.""" - def __init__(self, file_path, vector_store_config, text_splitter=None): + def __init__( + self, + file_path, + vector_store_config, + text_splitter: Optional[TextSplitter] = None, + ): """Initialize with word path.""" super().__init__(file_path, vector_store_config, text_splitter=None) self.file_path = file_path diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py index 1230bede9..60192873c 100644 --- a/pilot/vector_store/milvus_store.py +++ b/pilot/vector_store/milvus_store.py @@ -6,8 +6,6 @@ from pymilvus import Collection, DataType, connections, utility from pilot.vector_store.vector_store_base import VectorStoreBase - - class MilvusStore(VectorStoreBase): """Milvus database""" diff --git a/tests/unit/embedding_engine/test_url_embedding.py b/tests/unit/embedding_engine/test_url_embedding.py index 4cd7dcbd8..30f2a36cb 100644 --- a/tests/unit/embedding_engine/test_url_embedding.py +++ b/tests/unit/embedding_engine/test_url_embedding.py @@ -5,12 +5,16 @@ embedding_model = "text2vec" vector_store_type = "Chroma" chroma_persist_path = "your_persist_path" vector_store_config = { - "vector_store_name": url.replace(":", ""), - "vector_store_type": vector_store_type, - "chroma_persist_path": chroma_persist_path - } -embedding_engine = EmbeddingEngine(knowledge_source=url, knowledge_type=KnowledgeType.URL.value, model_name=embedding_model, vector_store_config=vector_store_config) + "vector_store_name": url.replace(":", ""), + "vector_store_type": vector_store_type, + "chroma_persist_path": chroma_persist_path, +} +embedding_engine = EmbeddingEngine( + knowledge_source=url, + knowledge_type=KnowledgeType.URL.value, + model_name=embedding_model, + vector_store_config=vector_store_config, +) # embedding url content to vector store embedding_engine.knowledge_embedding() - diff --git a/tools/knowledge_init.py b/tools/knowledge_init.py index 8ccd567c4..c442de8c9 100644 --- a/tools/knowledge_init.py +++ b/tools/knowledge_init.py @@ -14,7 +14,8 @@ from pilot.server.knowledge.request.request import KnowledgeSpaceRequest from pilot.configs.config import Config from pilot.configs.model_config import ( DATASETS_DIR, - LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH, + LLM_MODEL_CONFIG, + KNOWLEDGE_UPLOAD_ROOT_PATH, ) from pilot.embedding_engine.embedding_engine import EmbeddingEngine @@ -68,7 +69,11 @@ if __name__ == "__main__": args = parser.parse_args() vector_name = args.vector_name store_type = CFG.VECTOR_STORE_TYPE - vector_store_config = {"vector_store_name": vector_name, "vector_store_type": CFG.VECTOR_STORE_TYPE, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH} + vector_store_config = { + "vector_store_name": vector_name, + "vector_store_type": CFG.VECTOR_STORE_TYPE, + "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + } print(vector_store_config) kv = LocalKnowledgeInit(vector_store_config=vector_store_config) kv.knowledge_persist(file_path=DATASETS_DIR)