feat:embedding_engine add text_splitter param

This commit is contained in:
aries_ckt
2023-07-12 18:01:22 +08:00
parent 30adbaf4fd
commit 56c1947eda
6 changed files with 53 additions and 51 deletions

View File

@@ -2,6 +2,7 @@ from typing import Optional
from chromadb.errors import NotEnoughElementsException
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import TextSplitter
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding, KnowledgeType
from pilot.vector_store.connector import VectorStoreConnector
@@ -21,6 +22,7 @@ class EmbeddingEngine:
vector_store_config,
knowledge_type: Optional[str] = KnowledgeType.DOCUMENT.value,
knowledge_source: Optional[str] = None,
text_splitter: Optional[TextSplitter] = None,
):
"""Initialize with knowledge embedding client, model_name, vector_store_config, knowledge_type, knowledge_source"""
self.knowledge_source = knowledge_source
@@ -29,6 +31,7 @@ class EmbeddingEngine:
self.knowledge_type = knowledge_type
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
self.vector_store_config["embeddings"] = self.embeddings
self.text_splitter = text_splitter
def knowledge_embedding(self):
"""source embedding is chain process.read->text_split->data_process->index_store"""
@@ -47,7 +50,10 @@ class EmbeddingEngine:
def init_knowledge_embedding(self):
    """Build the knowledge embedding client for this engine.

    Forwards the engine's knowledge type, knowledge source, vector store
    config and optional text splitter to ``get_knowledge_embedding`` and
    returns whatever that factory returns.
    """
    # The text splitter is passed positionally as the fourth argument; it is
    # ``None`` unless the caller supplied one when constructing the engine.
    return get_knowledge_embedding(
        self.knowledge_type,
        self.knowledge_source,
        self.vector_store_config,
        self.text_splitter,
    )
def similar_search(self, text, topk):