feat:embedding_engine add text_splitter param

This commit is contained in:
aries_ckt 2023-07-12 18:01:22 +08:00
parent f911a8fa97
commit ff89e2e085
6 changed files with 53 additions and 51 deletions

View File

@@ -19,6 +19,7 @@ you will prepare embedding models from huggingface
Notice: make sure you have installed git-lfs
```{tip}
git clone https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
```
version:

View File

@@ -72,6 +72,24 @@ eg: git clone https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
vector_store_config=vector_store_config)
embedding_engine.knowledge_embedding()
If you want to use your own text_splitter, do this:
::
    url = "https://db-gpt.readthedocs.io/en/latest/getting_started/getting_started.html"
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=100, chunk_overlap=50
    )
    embedding_engine = EmbeddingEngine(
        knowledge_source=url,
        knowledge_type=KnowledgeType.URL.value,
        model_name=embedding_model,
        vector_store_config=vector_store_config,
        text_splitter=text_splitter,
    )
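Any splitter implementing langchain's TextSplitter interface can be passed the same way. Below is a minimal sketch using SpacyTextSplitter (the same splitter this release's knowledge service falls back to for Chinese); the document path is a placeholder, embedding_model and vector_store_config are assumed to be defined as above, and the zh_core_web_sm spaCy model must be installed:
::
    from langchain.text_splitter import SpacyTextSplitter

    text_splitter = SpacyTextSplitter(
        pipeline="zh_core_web_sm",
        chunk_size=100,
        chunk_overlap=50,
    )
    embedding_engine = EmbeddingEngine(
        knowledge_source="./your_doc.md",  # placeholder document path
        knowledge_type=KnowledgeType.DOCUMENT.value,
        model_name=embedding_model,
        vector_store_config=vector_store_config,
        text_splitter=text_splitter,
    )
    embedding_engine.knowledge_embedding()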
4.Init the Document-type EmbeddingEngine API and embed your document into the vector store in your code.
Document type can be .txt, .pdf, .md, .doc, .ppt.

View File

@@ -1,49 +0,0 @@
# Knowledge based QA
Chatting with your own knowledge is a very interesting thing. In the usage scenarios of this chapter, we will introduce how to build your own knowledge base through the knowledge base API. A knowledge store can currently be initialized by executing `python tool/knowledge_init.py`, which loads the content of your own knowledge base and was introduced in the previous knowledge base module. Of course, you can also call the provided knowledge embedding API to store knowledge.
We currently support many document formats: txt, pdf, md, html, doc, ppt, and url.
```
vector_store_config = {
    "vector_store_name": name
}
file_path = "your file path"
embedding_engine = EmbeddingEngine(
    file_path=file_path,
    model_name=LLM_MODEL_CONFIG["text2vec"],
    vector_store_config=vector_store_config,
)
embedding_engine.knowledge_embedding()
```
We currently support two vector databases: Chroma (the default) and Milvus. You can switch between them by modifying the "VECTOR_STORE_TYPE" field in the .env file.
```
#*******************************************************************#
#** VECTOR STORE SETTINGS **#
#*******************************************************************#
VECTOR_STORE_TYPE=Chroma
#MILVUS_URL=127.0.0.1
#MILVUS_PORT=19530
```
Below is an example of using the knowledge base API to query knowledge:
```
vector_store_config = {
    "vector_store_name": your_name,
    "vector_store_type": "Chroma",
    "chroma_persist_path": "your_persist_dir",
}
url = "your_url"
query = "your query"
embedding_model = "your_model_path/all-MiniLM-L6-v2"
embedding_engine = EmbeddingEngine(
    knowledge_source=url,
    knowledge_type=KnowledgeType.URL.value,
    model_name=embedding_model,
    vector_store_config=vector_store_config,
)
embedding_engine.similar_search(query, 10)
```
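similar_search returns the top-k matching chunks. A short sketch of consuming the results, assuming they come back as langchain Document objects (an assumption based on the langchain-based loaders used here, not shown in this file):
```
docs = embedding_engine.similar_search(query, 10)
for doc in docs:
    # Assumed langchain Document: page_content holds the chunk text,
    # metadata holds source information.
    print(doc.page_content, doc.metadata)
```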

View File

@@ -2,6 +2,7 @@ from typing import Optional
from chromadb.errors import NotEnoughElementsException
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import TextSplitter
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding, KnowledgeType
from pilot.vector_store.connector import VectorStoreConnector
@@ -21,6 +22,7 @@ class EmbeddingEngine:
        vector_store_config,
        knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value,
        knowledge_source: Optional[str] = None,
        text_splitter: Optional[TextSplitter] = None,
    ):
        """Initialize with knowledge embedding client: model_name, vector_store_config, knowledge_type, knowledge_source, and an optional text_splitter."""
        self.knowledge_source = knowledge_source
@@ -29,6 +31,7 @@ class EmbeddingEngine:
        self.knowledge_type = knowledge_type
        self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
        self.vector_store_config["embeddings"] = self.embeddings
        self.text_splitter = text_splitter

    def knowledge_embedding(self):
        """Source embedding is a chain process: read -> text_split -> data_process -> index_store."""
@@ -47,7 +50,10 @@
    def init_knowledge_embedding(self):
        return get_knowledge_embedding(
            self.knowledge_type,
            self.knowledge_source,
            self.vector_store_config,
            self.text_splitter,
        )

    def similar_search(self, text, topk):
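Taken together, the new parameter threads from EmbeddingEngine down to the knowledge classes. A minimal usage sketch (model path, source, and store name below are placeholders; when text_splitter is omitted it stays None, and the downstream knowledge class is presumably left to apply its own default splitter):
```
from langchain.text_splitter import RecursiveCharacterTextSplitter

from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.knowledge_type import KnowledgeType

engine = EmbeddingEngine(
    model_name="your_model_path/all-MiniLM-L6-v2",         # placeholder
    vector_store_config={"vector_store_name": "my_docs"},  # placeholder
    knowledge_type=KnowledgeType.DOCUMENT.value,
    knowledge_source="./your_doc.pdf",                     # placeholder
    text_splitter=RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=50),
)
engine.knowledge_embedding()  # read -> text_split -> data_process -> index_store
docs = engine.similar_search("your query", 5)
```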

View File

@@ -40,7 +40,9 @@ class KnowledgeType(Enum):
    YOUTUBE = "YOUTUBE"

def get_knowledge_embedding(
    knowledge_type, knowledge_source, vector_store_config, text_splitter
):
    match knowledge_type:
        case KnowledgeType.DOCUMENT.value:
            extension = "." + knowledge_source.rsplit(".", 1)[-1]
@@ -49,6 +51,7 @@ def get_knowledge_embedding(knowledge_type, knowledge_source, vector_store_config
            embedding = knowledge_class(
                knowledge_source,
                vector_store_config=vector_store_config,
                text_splitter=text_splitter,
                **knowledge_args,
            )
            return embedding
@@ -57,12 +60,14 @@ def get_knowledge_embedding(knowledge_type, knowledge_source, vector_store_config
            embedding = URLEmbedding(
                file_path=knowledge_source,
                vector_store_config=vector_store_config,
                text_splitter=text_splitter,
            )
            return embedding
        case KnowledgeType.TEXT.value:
            embedding = StringEmbedding(
                file_path=knowledge_source,
                vector_store_config=vector_store_config,
                text_splitter=text_splitter,
            )
            return embedding
        case KnowledgeType.OSS.value:
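For reference, a sketch of how a caller hits the URL branch above; the URL and store name are placeholders, and only construction is shown since the knowledge classes' own methods are outside this diff:
```
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding, KnowledgeType

# For DOCUMENT sources the extension picks the class, e.g. "report.pdf" -> ".pdf"
# via knowledge_source.rsplit(".", 1)[-1].
embedding = get_knowledge_embedding(
    KnowledgeType.URL.value,
    "https://db-gpt.readthedocs.io/en/latest/",  # placeholder source
    {"vector_store_name": "demo"},               # placeholder config
    None,  # text_splitter=None: the knowledge class keeps its default splitting
)
# EmbeddingEngine.knowledge_embedding() drives the returned object's
# read -> text_split -> data_process -> index_store chain.
```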

View File

@@ -1,6 +1,8 @@
import threading
from datetime import datetime

from langchain.text_splitter import RecursiveCharacterTextSplitter, SpacyTextSplitter

from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
@@ -122,6 +124,24 @@ class KnowledgeService:
            raise Exception(
                f" doc:{doc.doc_name} status is {doc.status}, can not sync"
            )
        if CFG.LANGUAGE == "en":
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
                chunk_overlap=20,
                length_function=len,
            )
        else:
            try:
                text_splitter = SpacyTextSplitter(
                    pipeline="zh_core_web_sm",
                    chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
                    chunk_overlap=100,
                )
            except Exception:
                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50
                )
        client = EmbeddingEngine(
            knowledge_source=doc.content,
            knowledge_type=doc.doc_type.upper(),
@@ -131,6 +151,7 @@ class KnowledgeService:
                "vector_store_type": CFG.VECTOR_STORE_TYPE,
                "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
            },
            text_splitter=text_splitter,
        )
        chunk_docs = client.read()
        # update document status
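The splitter selection above could be lifted into a helper so other call sites can reuse it; a hypothetical refactor sketch, not part of this commit:
```
from langchain.text_splitter import RecursiveCharacterTextSplitter, SpacyTextSplitter

def build_text_splitter(language: str, chunk_size: int):
    """Mirror the service's choice: a character splitter for English,
    spaCy's Chinese pipeline otherwise, with a character-based fallback."""
    if language == "en":
        return RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=20, length_function=len
        )
    try:
        # Requires the zh_core_web_sm spaCy model to be installed.
        return SpacyTextSplitter(
            pipeline="zh_core_web_sm", chunk_size=chunk_size, chunk_overlap=100
        )
    except Exception:
        return RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=50)
```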