refactor: refactor knowledge API

1. Delete CFG from the embedding_engine API
2. Add a text_splitter param to the embedding_engine API
This commit is contained in:
aries_ckt 2023-07-11 16:33:48 +08:00
parent 6ff7ef9da4
commit e6aa46fc87
24 changed files with 161 additions and 151 deletions

View File

@ -26,7 +26,7 @@ before execution:
:: ::
url = "https://db-gpt.readthedocs.io/en/latest/getting_started/getting_started.html" url = "https://db-gpt.readthedocs.io/en/latest/getting_started/getting_started.html"
embedding_model = "text2vec" embedding_model = "your_model_path/all-MiniLM-L6-v2"
vector_store_config = { vector_store_config = {
"vector_store_name": your_name, "vector_store_name": your_name,
} }
@ -43,9 +43,11 @@ Document type can be .txt, .pdf, .md, .doc, .ppt.
:: ::
document_path = "your_path/test.md" document_path = "your_path/test.md"
embedding_model = "text2vec" embedding_model = "your_model_path/all-MiniLM-L6-v2"
vector_store_config = { vector_store_config = {
"vector_store_name": your_name, "vector_store_name": your_name,
"vector_store_type": "Chroma",
"chroma_persist_path": "your_persist_dir",
} }
embedding_engine = EmbeddingEngine( embedding_engine = EmbeddingEngine(
knowledge_source=document_path, knowledge_source=document_path,
@ -59,7 +61,7 @@ Document type can be .txt, .pdf, .md, .doc, .ppt.
:: ::
raw_text = "a long passage" raw_text = "a long passage"
embedding_model = "text2vec" embedding_model = "your_model_path/all-MiniLM-L6-v2"
vector_store_config = { vector_store_config = {
"vector_store_name": your_name, "vector_store_name": your_name,
} }

View File

@ -32,11 +32,17 @@ Below is an example of using the knowledge base API to query knowledge:
``` ```
vector_store_config = { vector_store_config = {
"vector_store_name": name "vector_store_name": your_name,
"vector_store_type": "Chroma",
"chroma_persist_path": "your_persist_dir",
} }
integrate
query = "your query" query = "your query"
embedding_model = "your_model_path/all-MiniLM-L6-v2"
embedding_engine = EmbeddingEngine(knowledge_source=url, knowledge_type=KnowledgeType.URL.value, model_name=embedding_model, vector_store_config=vector_store_config) embedding_engine = EmbeddingEngine(knowledge_source=url, knowledge_type=KnowledgeType.URL.value, model_name=embedding_model, vector_store_config=vector_store_config)
embedding_engine.similar_search(query, 10) embedding_engine.similar_search(query, 10)

View File

@ -9,17 +9,12 @@ from pilot.embedding_engine import SourceEmbedding, register
class CSVEmbedding(SourceEmbedding): class CSVEmbedding(SourceEmbedding):
"""csv embedding for read csv document.""" """csv embedding for read csv document."""
def __init__( def __init__(self, file_path, vector_store_config, text_splitter=None):
self,
file_path,
vector_store_config,
embedding_args: Optional[Dict] = None,
):
"""Initialize with csv path.""" """Initialize with csv path."""
super().__init__(file_path, vector_store_config) super().__init__(file_path, vector_store_config, text_splitter=None)
self.file_path = file_path self.file_path = file_path
self.vector_store_config = vector_store_config self.vector_store_config = vector_store_config
self.embedding_args = embedding_args self.text_splitter = text_splitter or None
@register @register
def read(self): def read(self):

View File

@ -3,12 +3,9 @@ from typing import Optional
from chromadb.errors import NotEnoughElementsException from chromadb.errors import NotEnoughElementsException
from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding, KnowledgeType from pilot.embedding_engine.knowledge_type import get_knowledge_embedding, KnowledgeType
from pilot.vector_store.connector import VectorStoreConnector from pilot.vector_store.connector import VectorStoreConnector
CFG = Config()
class EmbeddingEngine: class EmbeddingEngine:
def __init__( def __init__(
@ -45,7 +42,7 @@ class EmbeddingEngine:
def similar_search(self, text, topk): def similar_search(self, text, topk):
vector_client = VectorStoreConnector( vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config self.vector_store_config["vector_store_type"], self.vector_store_config
) )
try: try:
ans = vector_client.similar_search(text, topk) ans = vector_client.similar_search(text, topk)
@ -55,12 +52,12 @@ class EmbeddingEngine:
def vector_exist(self): def vector_exist(self):
vector_client = VectorStoreConnector( vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config self.vector_store_config["vector_store_type"], self.vector_store_config
) )
return vector_client.vector_name_exists() return vector_client.vector_name_exists()
def delete_by_ids(self, ids): def delete_by_ids(self, ids):
vector_client = VectorStoreConnector( vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config self.vector_store_config["vector_store_type"], self.vector_store_config
) )
vector_client.delete_by_ids(ids=ids) vector_client.delete_by_ids(ids=ids)

View File

@ -11,6 +11,7 @@ from pilot.embedding_engine.word_embedding import WordEmbedding
DocumentEmbeddingType = { DocumentEmbeddingType = {
".txt": (MarkdownEmbedding, {}), ".txt": (MarkdownEmbedding, {}),
".md": (MarkdownEmbedding, {}), ".md": (MarkdownEmbedding, {}),
".html": (MarkdownEmbedding, {}),
".pdf": (PDFEmbedding, {}), ".pdf": (PDFEmbedding, {}),
".doc": (WordEmbedding, {}), ".doc": (WordEmbedding, {}),
".docx": (WordEmbedding, {}), ".docx": (WordEmbedding, {}),
@ -25,7 +26,18 @@ class KnowledgeType(Enum):
URL = "URL" URL = "URL"
TEXT = "TEXT" TEXT = "TEXT"
OSS = "OSS" OSS = "OSS"
S3 = "S3"
NOTION = "NOTION" NOTION = "NOTION"
MYSQL = "MYSQL"
TIDB = "TIDB"
CLICKHOUSE = "CLICKHOUSE"
OCEANBASE = "OCEANBASE"
ELASTICSEARCH = "ELASTICSEARCH"
HIVE = "HIVE"
PRESTO = "PRESTO"
KAFKA = "KAFKA"
SPARK = "SPARK"
YOUTUBE = "YOUTUBE"
def get_knowledge_embedding(knowledge_type, knowledge_source, vector_store_config): def get_knowledge_embedding(knowledge_type, knowledge_source, vector_store_config):
@ -55,8 +67,29 @@ def get_knowledge_embedding(knowledge_type, knowledge_source, vector_store_confi
return embedding return embedding
case KnowledgeType.OSS.value: case KnowledgeType.OSS.value:
raise Exception("OSS have not integrate") raise Exception("OSS have not integrate")
case KnowledgeType.S3.value:
raise Exception("S3 have not integrate")
case KnowledgeType.NOTION.value: case KnowledgeType.NOTION.value:
raise Exception("NOTION have not integrate") raise Exception("NOTION have not integrate")
case KnowledgeType.MYSQL.value:
raise Exception("MYSQL have not integrate")
case KnowledgeType.TIDB.value:
raise Exception("TIDB have not integrate")
case KnowledgeType.CLICKHOUSE.value:
raise Exception("CLICKHOUSE have not integrate")
case KnowledgeType.OCEANBASE.value:
raise Exception("OCEANBASE have not integrate")
case KnowledgeType.ELASTICSEARCH.value:
raise Exception("ELASTICSEARCH have not integrate")
case KnowledgeType.HIVE.value:
raise Exception("HIVE have not integrate")
case KnowledgeType.PRESTO.value:
raise Exception("PRESTO have not integrate")
case KnowledgeType.KAFKA.value:
raise Exception("KAFKA have not integrate")
case KnowledgeType.SPARK.value:
raise Exception("SPARK have not integrate")
case KnowledgeType.YOUTUBE.value:
raise Exception("YOUTUBE have not integrate")
case _: case _:
raise Exception("unknown knowledge type") raise Exception("unknown knowledge type")

View File

@ -12,46 +12,38 @@ from langchain.text_splitter import (
RecursiveCharacterTextSplitter, RecursiveCharacterTextSplitter,
) )
from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register from pilot.embedding_engine import SourceEmbedding, register
from pilot.embedding_engine.EncodeTextLoader import EncodeTextLoader from pilot.embedding_engine.EncodeTextLoader import EncodeTextLoader
CFG = Config()
class MarkdownEmbedding(SourceEmbedding): class MarkdownEmbedding(SourceEmbedding):
"""markdown embedding for read markdown document.""" """markdown embedding for read markdown document."""
def __init__(self, file_path, vector_store_config): def __init__(self, file_path, vector_store_config, text_splitter=None):
"""Initialize with markdown path.""" """Initialize raw text word path."""
super().__init__(file_path, vector_store_config) super().__init__(file_path, vector_store_config, text_splitter=None)
self.file_path = file_path self.file_path = file_path
self.vector_store_config = vector_store_config self.vector_store_config = vector_store_config
self.text_splitter = text_splitter or None
# self.encoding = encoding # self.encoding = encoding
@register @register
def read(self): def read(self):
"""Load from markdown path.""" """Load from markdown path."""
loader = EncodeTextLoader(self.file_path) loader = EncodeTextLoader(self.file_path)
if self.text_splitter is None:
if CFG.LANGUAGE == "en":
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=20,
length_function=len,
)
else:
try: try:
text_splitter = SpacyTextSplitter( self.text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_size=100,
chunk_overlap=100, chunk_overlap=100,
) )
except Exception: except Exception:
text_splitter = RecursiveCharacterTextSplitter( self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50 chunk_size=100, chunk_overlap=50
) )
return loader.load_and_split(text_splitter)
return loader.load_and_split(self.text_splitter)
@register @register
def data_process(self, documents: List[Document]): def data_process(self, documents: List[Document]):

View File

@ -6,51 +6,36 @@ from langchain.document_loaders import PyPDFLoader
from langchain.schema import Document from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter
from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register from pilot.embedding_engine import SourceEmbedding, register
CFG = Config()
class PDFEmbedding(SourceEmbedding): class PDFEmbedding(SourceEmbedding):
"""pdf embedding for read pdf document.""" """pdf embedding for read pdf document."""
def __init__(self, file_path, vector_store_config): def __init__(self, file_path, vector_store_config, text_splitter=None):
"""Initialize with pdf path.""" """Initialize pdf word path."""
super().__init__(file_path, vector_store_config) super().__init__(file_path, vector_store_config, text_splitter=None)
self.file_path = file_path self.file_path = file_path
self.vector_store_config = vector_store_config self.vector_store_config = vector_store_config
self.text_splitter = text_splitter or None
@register @register
def read(self): def read(self):
"""Load from pdf path.""" """Load from pdf path."""
loader = PyPDFLoader(self.file_path) loader = PyPDFLoader(self.file_path)
# textsplitter = CHNDocumentSplitter( if self.text_splitter is None:
# pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
# )
# textsplitter = SpacyTextSplitter(
# pipeline="zh_core_web_sm",
# chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
# chunk_overlap=100,
# )
if CFG.LANGUAGE == "en":
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=20,
length_function=len,
)
else:
try: try:
text_splitter = SpacyTextSplitter( self.text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_size=100,
chunk_overlap=100, chunk_overlap=100,
) )
except Exception: except Exception:
text_splitter = RecursiveCharacterTextSplitter( self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50 chunk_size=100, chunk_overlap=50
) )
return loader.load_and_split(text_splitter)
return loader.load_and_split(self.text_splitter)
@register @register
def data_process(self, documents: List[Document]): def data_process(self, documents: List[Document]):

View File

@ -6,48 +6,36 @@ from langchain.document_loaders import UnstructuredPowerPointLoader
from langchain.schema import Document from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter
from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register from pilot.embedding_engine import SourceEmbedding, register
CFG = Config()
class PPTEmbedding(SourceEmbedding): class PPTEmbedding(SourceEmbedding):
"""ppt embedding for read ppt document.""" """ppt embedding for read ppt document."""
def __init__(self, file_path, vector_store_config): def __init__(self, file_path, vector_store_config, text_splitter=None):
"""Initialize with pdf path.""" """Initialize ppt word path."""
super().__init__(file_path, vector_store_config) super().__init__(file_path, vector_store_config, text_splitter=None)
self.file_path = file_path self.file_path = file_path
self.vector_store_config = vector_store_config self.vector_store_config = vector_store_config
self.text_splitter = text_splitter or None
@register @register
def read(self): def read(self):
"""Load from ppt path.""" """Load from ppt path."""
loader = UnstructuredPowerPointLoader(self.file_path) loader = UnstructuredPowerPointLoader(self.file_path)
# textsplitter = SpacyTextSplitter( if self.text_splitter is None:
# pipeline="zh_core_web_sm",
# chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
# chunk_overlap=200,
# )
if CFG.LANGUAGE == "en":
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=20,
length_function=len,
)
else:
try: try:
text_splitter = SpacyTextSplitter( self.text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_size=100,
chunk_overlap=100, chunk_overlap=100,
) )
except Exception: except Exception:
text_splitter = RecursiveCharacterTextSplitter( self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50 chunk_size=100, chunk_overlap=50
) )
return loader.load_and_split(text_splitter)
return loader.load_and_split(self.text_splitter)
@register @register
def data_process(self, documents: List[Document]): def data_process(self, documents: List[Document]):

View File

@ -4,11 +4,11 @@ from abc import ABC, abstractmethod
from typing import Dict, List, Optional from typing import Dict, List, Optional
from chromadb.errors import NotEnoughElementsException from chromadb.errors import NotEnoughElementsException
from pilot.configs.config import Config from langchain.text_splitter import TextSplitter
from pilot.vector_store.connector import VectorStoreConnector from pilot.vector_store.connector import VectorStoreConnector
registered_methods = [] registered_methods = []
CFG = Config()
def register(method): def register(method):
@ -25,12 +25,14 @@ class SourceEmbedding(ABC):
def __init__( def __init__(
self, self,
file_path, file_path,
vector_store_config, vector_store_config: {},
text_splitter: TextSplitter = None,
embedding_args: Optional[Dict] = None, embedding_args: Optional[Dict] = None,
): ):
"""Initialize with Loader url, model_name, vector_store_config""" """Initialize with Loader url, model_name, vector_store_config"""
self.file_path = file_path self.file_path = file_path
self.vector_store_config = vector_store_config self.vector_store_config = vector_store_config
self.text_splitter = text_splitter
self.embedding_args = embedding_args self.embedding_args = embedding_args
self.embeddings = vector_store_config["embeddings"] self.embeddings = vector_store_config["embeddings"]
@ -44,8 +46,8 @@ class SourceEmbedding(ABC):
"""pre process data.""" """pre process data."""
@register @register
def text_split(self, text): def text_splitter(self, text_splitter: TextSplitter):
"""text split chunk""" """add text split chunk"""
pass pass
@register @register
@ -57,7 +59,7 @@ class SourceEmbedding(ABC):
def index_to_store(self, docs): def index_to_store(self, docs):
"""index to vector store""" """index to vector store"""
self.vector_client = VectorStoreConnector( self.vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config self.vector_store_config["vector_store_type"], self.vector_store_config
) )
return self.vector_client.load_document(docs) return self.vector_client.load_document(docs)
@ -65,7 +67,7 @@ class SourceEmbedding(ABC):
def similar_search(self, doc, topk): def similar_search(self, doc, topk):
"""vector store similarity_search""" """vector store similarity_search"""
self.vector_client = VectorStoreConnector( self.vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config self.vector_store_config["vector_store_type"], self.vector_store_config
) )
try: try:
ans = self.vector_client.similar_search(doc, topk) ans = self.vector_client.similar_search(doc, topk)
@ -75,7 +77,7 @@ class SourceEmbedding(ABC):
def vector_name_exist(self): def vector_name_exist(self):
self.vector_client = VectorStoreConnector( self.vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config self.vector_store_config["vector_store_type"], self.vector_store_config
) )
return self.vector_client.vector_name_exists() return self.vector_client.vector_name_exists()

View File

@ -8,11 +8,12 @@ from pilot.embedding_engine import SourceEmbedding, register
class StringEmbedding(SourceEmbedding): class StringEmbedding(SourceEmbedding):
"""string embedding for read string document.""" """string embedding for read string document."""
def __init__(self, file_path, vector_store_config): def __init__(self, file_path, vector_store_config, text_splitter=None):
"""Initialize with pdf path.""" """Initialize raw text word path."""
super().__init__(file_path, vector_store_config) super().__init__(file_path, vector_store_config, text_splitter=None)
self.file_path = file_path self.file_path = file_path
self.vector_store_config = vector_store_config self.vector_store_config = vector_store_config
self.text_splitter = text_splitter or None
@register @register
def read(self): def read(self):

View File

@ -5,43 +5,37 @@ from langchain.document_loaders import WebBaseLoader
from langchain.schema import Document from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter
from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register from pilot.embedding_engine import SourceEmbedding, register
CFG = Config()
class URLEmbedding(SourceEmbedding): class URLEmbedding(SourceEmbedding):
"""url embedding for read url document.""" """url embedding for read url document."""
def __init__(self, file_path, vector_store_config): def __init__(self, file_path, vector_store_config, text_splitter=None):
"""Initialize with url path.""" """Initialize url word path."""
super().__init__(file_path, vector_store_config) super().__init__(file_path, vector_store_config, text_splitter=None)
self.file_path = file_path self.file_path = file_path
self.vector_store_config = vector_store_config self.vector_store_config = vector_store_config
self.text_splitter = text_splitter or None
@register @register
def read(self): def read(self):
"""Load from url path.""" """Load from url path."""
loader = WebBaseLoader(web_path=self.file_path) loader = WebBaseLoader(web_path=self.file_path)
if CFG.LANGUAGE == "en": if self.text_splitter is None:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=20,
length_function=len,
)
else:
try: try:
text_splitter = SpacyTextSplitter( self.text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_size=100,
chunk_overlap=100, chunk_overlap=100,
) )
except Exception: except Exception:
text_splitter = RecursiveCharacterTextSplitter( self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50 chunk_size=100, chunk_overlap=50
) )
return loader.load_and_split(text_splitter)
return loader.load_and_split(self.text_splitter)
@register @register
def data_process(self, documents: List[Document]): def data_process(self, documents: List[Document]):

View File

@ -6,43 +6,36 @@ from langchain.document_loaders import UnstructuredWordDocumentLoader
from langchain.schema import Document from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter
from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register from pilot.embedding_engine import SourceEmbedding, register
CFG = Config()
class WordEmbedding(SourceEmbedding): class WordEmbedding(SourceEmbedding):
"""word embedding for read word document.""" """word embedding for read word document."""
def __init__(self, file_path, vector_store_config): def __init__(self, file_path, vector_store_config, text_splitter=None):
"""Initialize with word path.""" """Initialize with word path."""
super().__init__(file_path, vector_store_config) super().__init__(file_path, vector_store_config, text_splitter=None)
self.file_path = file_path self.file_path = file_path
self.vector_store_config = vector_store_config self.vector_store_config = vector_store_config
self.text_splitter = text_splitter or None
@register @register
def read(self): def read(self):
"""Load from word path.""" """Load from word path."""
loader = UnstructuredWordDocumentLoader(self.file_path) loader = UnstructuredWordDocumentLoader(self.file_path)
if CFG.LANGUAGE == "en": if self.text_splitter is None:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
chunk_overlap=20,
length_function=len,
)
else:
try: try:
text_splitter = SpacyTextSplitter( self.text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm", pipeline="zh_core_web_sm",
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_size=100,
chunk_overlap=100, chunk_overlap=100,
) )
except Exception: except Exception:
text_splitter = RecursiveCharacterTextSplitter( self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50 chunk_size=100, chunk_overlap=50
) )
return loader.load_and_split(text_splitter)
return loader.load_and_split(self.text_splitter)
@register @register
def data_process(self, documents: List[Document]): def data_process(self, documents: List[Document]):

View File

@ -37,8 +37,8 @@ class ChatNewKnowledge(BaseChat):
self.knowledge_name = knowledge_name self.knowledge_name = knowledge_name
vector_store_config = { vector_store_config = {
"vector_store_name": knowledge_name, "vector_store_name": knowledge_name,
"text_field": "content", "vector_store_type": CFG.VECTOR_STORE_TYPE,
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
self.knowledge_embedding_client = EmbeddingEngine( self.knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG["text2vec"], model_name=LLM_MODEL_CONFIG["text2vec"],

View File

@ -38,7 +38,8 @@ class ChatDefaultKnowledge(BaseChat):
) )
vector_store_config = { vector_store_config = {
"vector_store_name": "default", "vector_store_name": "default",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
self.knowledge_embedding_client = EmbeddingEngine( self.knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG["text2vec"], model_name=LLM_MODEL_CONFIG["text2vec"],

View File

@ -38,7 +38,8 @@ class ChatUrlKnowledge(BaseChat):
self.url = url self.url = url
vector_store_config = { vector_store_config = {
"vector_store_name": url.replace(":", ""), "vector_store_name": url.replace(":", ""),
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
self.knowledge_embedding_client = EmbeddingEngine( self.knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],

View File

@ -38,7 +38,8 @@ class ChatKnowledge(BaseChat):
) )
vector_store_config = { vector_store_config = {
"vector_store_name": knowledge_space, "vector_store_name": knowledge_space,
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
self.knowledge_embedding_client = EmbeddingEngine( self.knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],

View File

@ -2,7 +2,7 @@ import threading
from datetime import datetime from datetime import datetime
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.logs import logger from pilot.logs import logger
from pilot.server.knowledge.chunk_db import ( from pilot.server.knowledge.chunk_db import (
@ -128,6 +128,8 @@ class KnowledgeService:
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={ vector_store_config={
"vector_store_name": space_name, "vector_store_name": space_name,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}, },
) )
chunk_docs = client.read() chunk_docs = client.read()

View File

@ -665,6 +665,7 @@ def knowledge_embedding_store(vs_id, files):
model_name=LLM_MODEL_CONFIG["text2vec"], model_name=LLM_MODEL_CONFIG["text2vec"],
vector_store_config={ vector_store_config={
"vector_store_name": vector_store_name["vs_name"], "vector_store_name": vector_store_name["vs_name"],
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}, },
) )

View File

@ -4,7 +4,7 @@ import uuid
from langchain.embeddings import HuggingFaceEmbeddings, logger from langchain.embeddings import HuggingFaceEmbeddings, logger
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.embedding_engine.embedding_engine import EmbeddingEngine
@ -33,6 +33,8 @@ class DBSummaryClient:
) )
vector_store_config = { vector_store_config = {
"vector_store_name": dbname + "_summary", "vector_store_name": dbname + "_summary",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
"embeddings": embeddings, "embeddings": embeddings,
} }
embedding = StringEmbedding( embedding = StringEmbedding(
@ -60,6 +62,8 @@ class DBSummaryClient:
) in db_summary_client.get_table_summary().items(): ) in db_summary_client.get_table_summary().items():
table_vector_store_config = { table_vector_store_config = {
"vector_store_name": dbname + "_" + table_name + "_ts", "vector_store_name": dbname + "_" + table_name + "_ts",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
"embeddings": embeddings, "embeddings": embeddings,
} }
embedding = StringEmbedding( embedding = StringEmbedding(
@ -73,6 +77,9 @@ class DBSummaryClient:
def get_db_summary(self, dbname, query, topk): def get_db_summary(self, dbname, query, topk):
vector_store_config = { vector_store_config = {
"vector_store_name": dbname + "_profile", "vector_store_name": dbname + "_profile",
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
knowledge_embedding_client = EmbeddingEngine( knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
@ -86,6 +93,9 @@ class DBSummaryClient:
"""get user query related tables info""" """get user query related tables info"""
vector_store_config = { vector_store_config = {
"vector_store_name": dbname + "_summary", "vector_store_name": dbname + "_summary",
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
knowledge_embedding_client = EmbeddingEngine( knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
@ -109,6 +119,9 @@ class DBSummaryClient:
for table in related_tables: for table in related_tables:
vector_store_config = { vector_store_config = {
"vector_store_name": dbname + "_" + table + "_ts", "vector_store_name": dbname + "_" + table + "_ts",
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
knowledge_embedding_client = EmbeddingEngine( knowledge_embedding_client = EmbeddingEngine(
file_path="", file_path="",
@ -128,6 +141,8 @@ class DBSummaryClient:
def init_db_profile(self, db_summary_client, dbname, embeddings): def init_db_profile(self, db_summary_client, dbname, embeddings):
profile_store_config = { profile_store_config = {
"vector_store_name": dbname + "_profile", "vector_store_name": dbname + "_profile",
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"embeddings": embeddings, "embeddings": embeddings,
} }
embedding = StringEmbedding( embedding = StringEmbedding(

View File

@ -1,7 +1,6 @@
import os import os
from langchain.vectorstores import Chroma from langchain.vectorstores import Chroma
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.logs import logger from pilot.logs import logger
from pilot.vector_store.vector_store_base import VectorStoreBase from pilot.vector_store.vector_store_base import VectorStoreBase
@ -13,7 +12,7 @@ class ChromaStore(VectorStoreBase):
self.ctx = ctx self.ctx = ctx
self.embeddings = ctx["embeddings"] self.embeddings = ctx["embeddings"]
self.persist_dir = os.path.join( self.persist_dir = os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, ctx["vector_store_name"] + ".vectordb" ctx["chroma_persist_path"], ctx["vector_store_name"] + ".vectordb"
) )
self.vector_store_client = Chroma( self.vector_store_client = Chroma(
persist_directory=self.persist_dir, embedding_function=self.embeddings persist_directory=self.persist_dir, embedding_function=self.embeddings

View File

@ -1,8 +1,8 @@
from pilot.vector_store.chroma_store import ChromaStore from pilot.vector_store.chroma_store import ChromaStore
# from pilot.vector_store.milvus_store import MilvusStore from pilot.vector_store.milvus_store import MilvusStore
connector = {"Chroma": ChromaStore, "Milvus": None} connector = {"Chroma": ChromaStore, "Milvus": MilvusStore}
class VectorStoreConnector: class VectorStoreConnector:

View File

@ -3,11 +3,9 @@ from typing import Any, Iterable, List, Optional, Tuple
from langchain.docstore.document import Document from langchain.docstore.document import Document
from pymilvus import Collection, DataType, connections, utility from pymilvus import Collection, DataType, connections, utility
from pilot.configs.config import Config
from pilot.vector_store.vector_store_base import VectorStoreBase from pilot.vector_store.vector_store_base import VectorStoreBase
CFG = Config()
class MilvusStore(VectorStoreBase): class MilvusStore(VectorStoreBase):
@ -22,10 +20,10 @@ class MilvusStore(VectorStoreBase):
# self.configure(cfg) # self.configure(cfg)
connect_kwargs = {} connect_kwargs = {}
self.uri = CFG.MILVUS_URL self.uri = ctx.get("milvus_url", None)
self.port = CFG.MILVUS_PORT self.port = ctx.get("milvus_port", None)
self.username = CFG.MILVUS_USERNAME self.username = ctx.get("milvus_username", None)
self.password = CFG.MILVUS_PASSWORD self.password = ctx.get("milvus_password", None)
self.collection_name = ctx.get("vector_store_name", None) self.collection_name = ctx.get("vector_store_name", None)
self.secure = ctx.get("secure", None) self.secure = ctx.get("secure", None)
self.embedding = ctx.get("embeddings", None) self.embedding = ctx.get("embeddings", None)

View File

@ -2,8 +2,12 @@ from pilot import EmbeddingEngine, KnowledgeType
url = "https://db-gpt.readthedocs.io/en/latest/getting_started/getting_started.html" url = "https://db-gpt.readthedocs.io/en/latest/getting_started/getting_started.html"
embedding_model = "text2vec" embedding_model = "text2vec"
vector_store_type = "Chroma"
chroma_persist_path = "your_persist_path"
vector_store_config = { vector_store_config = {
"vector_store_name": url.replace(":", ""), "vector_store_name": url.replace(":", ""),
"vector_store_type": vector_store_type,
"chroma_persist_path": chroma_persist_path
} }
embedding_engine = EmbeddingEngine(knowledge_source=url, knowledge_type=KnowledgeType.URL.value, model_name=embedding_model, vector_store_config=vector_store_config) embedding_engine = EmbeddingEngine(knowledge_source=url, knowledge_type=KnowledgeType.URL.value, model_name=embedding_model, vector_store_config=vector_store_config)

View File

@ -14,7 +14,7 @@ from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import ( from pilot.configs.model_config import (
DATASETS_DIR, DATASETS_DIR,
LLM_MODEL_CONFIG, LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH,
) )
from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.embedding_engine.embedding_engine import EmbeddingEngine
@ -68,7 +68,7 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
vector_name = args.vector_name vector_name = args.vector_name
store_type = CFG.VECTOR_STORE_TYPE store_type = CFG.VECTOR_STORE_TYPE
vector_store_config = {"vector_store_name": vector_name} vector_store_config = {"vector_store_name": vector_name, "vector_store_type": CFG.VECTOR_STORE_TYPE, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH}
print(vector_store_config) print(vector_store_config)
kv = LocalKnowledgeInit(vector_store_config=vector_store_config) kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
kv.knowledge_persist(file_path=DATASETS_DIR) kv.knowledge_persist(file_path=DATASETS_DIR)