refactor:refactor knowledge api

1.delete CFG in embedding_engine api
2.add a text_splitter param in embedding_engine api
This commit is contained in:
aries_ckt
2023-07-11 16:33:48 +08:00
parent 6ff7ef9da4
commit e6aa46fc87
24 changed files with 161 additions and 151 deletions

View File

@@ -4,11 +4,11 @@ from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from chromadb.errors import NotEnoughElementsException
from pilot.configs.config import Config
from langchain.text_splitter import TextSplitter
from pilot.vector_store.connector import VectorStoreConnector
registered_methods = []
CFG = Config()
def register(method):
@@ -25,12 +25,14 @@ class SourceEmbedding(ABC):
def __init__(
self,
file_path,
vector_store_config,
vector_store_config: {},
text_splitter: TextSplitter = None,
embedding_args: Optional[Dict] = None,
):
"""Initialize with Loader url, model_name, vector_store_config"""
self.file_path = file_path
self.vector_store_config = vector_store_config
self.text_splitter = text_splitter
self.embedding_args = embedding_args
self.embeddings = vector_store_config["embeddings"]
@@ -44,8 +46,8 @@ class SourceEmbedding(ABC):
"""pre process data."""
@register
def text_split(self, text):
"""text split chunk"""
def text_splitter(self, text_splitter: TextSplitter):
"""add text split chunk"""
pass
@register
@@ -57,7 +59,7 @@ class SourceEmbedding(ABC):
def index_to_store(self, docs):
"""index to vector store"""
self.vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config
self.vector_store_config["vector_store_type"], self.vector_store_config
)
return self.vector_client.load_document(docs)
@@ -65,7 +67,7 @@ class SourceEmbedding(ABC):
def similar_search(self, doc, topk):
"""vector store similarity_search"""
self.vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config
self.vector_store_config["vector_store_type"], self.vector_store_config
)
try:
ans = self.vector_client.similar_search(doc, topk)
@@ -75,7 +77,7 @@ class SourceEmbedding(ABC):
def vector_name_exist(self):
self.vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config
self.vector_store_config["vector_store_type"], self.vector_store_config
)
return self.vector_client.vector_name_exists()