mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-10 03:03:57 +00:00
refactor:refactor knowledge api
1.delete CFG in embedding_engine api 2.add a text_splitter param in embedding_engine api
This commit is contained in:
@@ -4,11 +4,11 @@ from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from chromadb.errors import NotEnoughElementsException
|
||||
from pilot.configs.config import Config
|
||||
from langchain.text_splitter import TextSplitter
|
||||
|
||||
from pilot.vector_store.connector import VectorStoreConnector
|
||||
|
||||
registered_methods = []
|
||||
CFG = Config()
|
||||
|
||||
|
||||
def register(method):
|
||||
@@ -25,12 +25,14 @@ class SourceEmbedding(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
file_path,
|
||||
vector_store_config,
|
||||
vector_store_config: {},
|
||||
text_splitter: TextSplitter = None,
|
||||
embedding_args: Optional[Dict] = None,
|
||||
):
|
||||
"""Initialize with Loader url, model_name, vector_store_config"""
|
||||
self.file_path = file_path
|
||||
self.vector_store_config = vector_store_config
|
||||
self.text_splitter = text_splitter
|
||||
self.embedding_args = embedding_args
|
||||
self.embeddings = vector_store_config["embeddings"]
|
||||
|
||||
@@ -44,8 +46,8 @@ class SourceEmbedding(ABC):
|
||||
"""pre process data."""
|
||||
|
||||
@register
|
||||
def text_split(self, text):
|
||||
"""text split chunk"""
|
||||
def text_splitter(self, text_splitter: TextSplitter):
|
||||
"""add text split chunk"""
|
||||
pass
|
||||
|
||||
@register
|
||||
@@ -57,7 +59,7 @@ class SourceEmbedding(ABC):
|
||||
def index_to_store(self, docs):
|
||||
"""index to vector store"""
|
||||
self.vector_client = VectorStoreConnector(
|
||||
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
||||
self.vector_store_config["vector_store_type"], self.vector_store_config
|
||||
)
|
||||
return self.vector_client.load_document(docs)
|
||||
|
||||
@@ -65,7 +67,7 @@ class SourceEmbedding(ABC):
|
||||
def similar_search(self, doc, topk):
|
||||
"""vector store similarity_search"""
|
||||
self.vector_client = VectorStoreConnector(
|
||||
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
||||
self.vector_store_config["vector_store_type"], self.vector_store_config
|
||||
)
|
||||
try:
|
||||
ans = self.vector_client.similar_search(doc, topk)
|
||||
@@ -75,7 +77,7 @@ class SourceEmbedding(ABC):
|
||||
|
||||
def vector_name_exist(self):
|
||||
self.vector_client = VectorStoreConnector(
|
||||
CFG.VECTOR_STORE_TYPE, self.vector_store_config
|
||||
self.vector_store_config["vector_store_type"], self.vector_store_config
|
||||
)
|
||||
return self.vector_client.vector_name_exists()
|
||||
|
||||
|
Reference in New Issue
Block a user