mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-09 04:49:26 +00:00
feature:add milvus store
This commit is contained in:
@@ -48,3 +48,5 @@ DB_SETTINGS = {
|
|||||||
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
|
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store")
|
||||||
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
|
KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
|
||||||
KNOWLEDGE_CHUNK_SPLIT_SIZE = 100
|
KNOWLEDGE_CHUNK_SPLIT_SIZE = 100
|
||||||
|
VECTOR_STORE_TYPE = "milvus"
|
||||||
|
VECTOR_STORE_CONFIG = {"url": "127.0.0.1", "port": "19530"}
|
||||||
|
@@ -19,7 +19,8 @@ from langchain import PromptTemplate
|
|||||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(ROOT_PATH)
|
sys.path.append(ROOT_PATH)
|
||||||
|
|
||||||
from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K
|
from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, \
|
||||||
|
VECTOR_STORE_CONFIG
|
||||||
from pilot.server.vectordb_qa import KnownLedgeBaseQA
|
from pilot.server.vectordb_qa import KnownLedgeBaseQA
|
||||||
from pilot.connections.mysql import MySQLOperator
|
from pilot.connections.mysql import MySQLOperator
|
||||||
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
||||||
@@ -267,12 +268,16 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
|||||||
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
|
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
|
||||||
|
|
||||||
if mode == conversation_types["custome"] and not db_selector:
|
if mode == conversation_types["custome"] and not db_selector:
|
||||||
persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb")
|
# persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"])
|
||||||
print("vector store path: ", persist_dir)
|
print("vector store type: ", VECTOR_STORE_CONFIG)
|
||||||
|
print("vector store name: ", vector_store_name["vs_name"])
|
||||||
|
vector_store_config = VECTOR_STORE_CONFIG
|
||||||
|
vector_store_config["vector_store_name"] = vector_store_name["vs_name"]
|
||||||
|
vector_store_config["text_field"] = "content"
|
||||||
|
vector_store_config["vector_store_path"] = KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||||
knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"],
|
knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"],
|
||||||
local_persist=False,
|
local_persist=False,
|
||||||
vector_store_config={"vector_store_name": vector_store_name["vs_name"],
|
vector_store_config=vector_store_config)
|
||||||
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
|
|
||||||
query = state.messages[-2][1]
|
query = state.messages[-2][1]
|
||||||
docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
|
docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
|
||||||
context = [d.page_content for d in docs]
|
context = [d.page_content for d in docs]
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from langchain.document_loaders import PyPDFLoader, TextLoader, markdown
|
from langchain.document_loaders import TextLoader, markdown
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
from langchain.vectorstores import Chroma
|
from langchain.vectorstores import Chroma
|
||||||
from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE
|
from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE
|
||||||
@@ -12,6 +12,7 @@ from pilot.source_embedding.pdf_embedding import PDFEmbedding
|
|||||||
import markdown
|
import markdown
|
||||||
|
|
||||||
from pilot.source_embedding.pdf_loader import UnstructuredPaddlePDFLoader
|
from pilot.source_embedding.pdf_loader import UnstructuredPaddlePDFLoader
|
||||||
|
from pilot.vector_store.milvus_store import MilvusStore
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeEmbedding:
|
class KnowledgeEmbedding:
|
||||||
@@ -20,7 +21,7 @@ class KnowledgeEmbedding:
|
|||||||
self.file_path = file_path
|
self.file_path = file_path
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.vector_store_config = vector_store_config
|
self.vector_store_config = vector_store_config
|
||||||
self.vector_store_type = "default"
|
self.file_type = "default"
|
||||||
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
||||||
self.local_persist = local_persist
|
self.local_persist = local_persist
|
||||||
if not self.local_persist:
|
if not self.local_persist:
|
||||||
@@ -42,7 +43,7 @@ class KnowledgeEmbedding:
|
|||||||
elif self.file_path.endswith(".csv"):
|
elif self.file_path.endswith(".csv"):
|
||||||
embedding = CSVEmbedding(file_path=self.file_path, model_name=self.model_name,
|
embedding = CSVEmbedding(file_path=self.file_path, model_name=self.model_name,
|
||||||
vector_store_config=self.vector_store_config)
|
vector_store_config=self.vector_store_config)
|
||||||
elif self.vector_store_type == "default":
|
elif self.file_type == "default":
|
||||||
embedding = MarkdownEmbedding(file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config)
|
embedding = MarkdownEmbedding(file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config)
|
||||||
|
|
||||||
return embedding
|
return embedding
|
||||||
@@ -52,6 +53,8 @@ class KnowledgeEmbedding:
|
|||||||
|
|
||||||
def knowledge_persist_initialization(self, append_mode):
|
def knowledge_persist_initialization(self, append_mode):
|
||||||
vector_name = self.vector_store_config["vector_store_name"]
|
vector_name = self.vector_store_config["vector_store_name"]
|
||||||
|
documents = self._load_knownlege(self.file_path)
|
||||||
|
if self.vector_store_config["vector_store_type"] == "Chroma":
|
||||||
persist_dir = os.path.join(self.vector_store_config["vector_store_path"], vector_name + ".vectordb")
|
persist_dir = os.path.join(self.vector_store_config["vector_store_path"], vector_name + ".vectordb")
|
||||||
print("vector db path: ", persist_dir)
|
print("vector db path: ", persist_dir)
|
||||||
if os.path.exists(persist_dir):
|
if os.path.exists(persist_dir):
|
||||||
@@ -66,11 +69,17 @@ class KnowledgeEmbedding:
|
|||||||
vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
|
vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
|
||||||
else:
|
else:
|
||||||
print(vector_name + " is new vector store, knowledge begin load...")
|
print(vector_name + " is new vector store, knowledge begin load...")
|
||||||
documents = self._load_knownlege(self.file_path)
|
|
||||||
vector_store = Chroma.from_documents(documents=documents,
|
vector_store = Chroma.from_documents(documents=documents,
|
||||||
embedding=self.embeddings,
|
embedding=self.embeddings,
|
||||||
persist_directory=persist_dir)
|
persist_directory=persist_dir)
|
||||||
vector_store.persist()
|
vector_store.persist()
|
||||||
|
|
||||||
|
elif self.vector_store_config["vector_store_type"] == "milvus":
|
||||||
|
vector_store = MilvusStore({"url": self.vector_store_config["url"],
|
||||||
|
"port": self.vector_store_config["port"],
|
||||||
|
"embedding": self.embeddings})
|
||||||
|
vector_store.init_schema_and_load(vector_name, documents)
|
||||||
|
|
||||||
return vector_store
|
return vector_store
|
||||||
|
|
||||||
def _load_knownlege(self, path):
|
def _load_knownlege(self, path):
|
||||||
|
@@ -5,9 +5,14 @@ from abc import ABC, abstractmethod
|
|||||||
|
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
from langchain.vectorstores import Chroma
|
from langchain.vectorstores import Chroma
|
||||||
|
from langchain.vectorstores import Milvus
|
||||||
|
|
||||||
from typing import List, Optional, Dict
|
from typing import List, Optional, Dict
|
||||||
|
|
||||||
|
|
||||||
|
from pilot.configs.model_config import VECTOR_STORE_TYPE, VECTOR_STORE_CONFIG
|
||||||
|
from pilot.vector_store.milvus_store import MilvusStore
|
||||||
|
|
||||||
registered_methods = []
|
registered_methods = []
|
||||||
|
|
||||||
|
|
||||||
@@ -29,6 +34,17 @@ class SourceEmbedding(ABC):
|
|||||||
self.vector_store_config = vector_store_config
|
self.vector_store_config = vector_store_config
|
||||||
self.embedding_args = embedding_args
|
self.embedding_args = embedding_args
|
||||||
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
||||||
|
|
||||||
|
if VECTOR_STORE_TYPE == "milvus":
|
||||||
|
print(VECTOR_STORE_CONFIG)
|
||||||
|
if self.vector_store_config.get("text_field") is None:
|
||||||
|
self.vector_store_client = MilvusStore({"url": VECTOR_STORE_CONFIG["url"],
|
||||||
|
"port": VECTOR_STORE_CONFIG["port"],
|
||||||
|
"embedding": self.embeddings})
|
||||||
|
else:
|
||||||
|
self.vector_store_client = Milvus(embedding_function=self.embeddings, collection_name=self.vector_store_config["vector_store_name"], text_field="content",
|
||||||
|
connection_args={"host": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"]})
|
||||||
|
else:
|
||||||
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
|
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
|
||||||
self.vector_store_config["vector_store_name"] + ".vectordb")
|
self.vector_store_config["vector_store_name"] + ".vectordb")
|
||||||
self.vector_store_client = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
|
self.vector_store_client = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
|
||||||
@@ -54,11 +70,19 @@ class SourceEmbedding(ABC):
|
|||||||
@register
|
@register
|
||||||
def index_to_store(self, docs):
|
def index_to_store(self, docs):
|
||||||
"""index to vector store"""
|
"""index to vector store"""
|
||||||
|
|
||||||
|
if VECTOR_STORE_TYPE == "chroma":
|
||||||
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
|
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
|
||||||
self.vector_store_config["vector_store_name"] + ".vectordb")
|
self.vector_store_config["vector_store_name"] + ".vectordb")
|
||||||
self.vector_store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir)
|
self.vector_store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir)
|
||||||
self.vector_store.persist()
|
self.vector_store.persist()
|
||||||
|
|
||||||
|
elif VECTOR_STORE_TYPE == "milvus":
|
||||||
|
self.vector_store = MilvusStore({"url": VECTOR_STORE_CONFIG["url"],
|
||||||
|
"port": VECTOR_STORE_CONFIG["port"],
|
||||||
|
"embedding": self.embeddings})
|
||||||
|
self.vector_store.init_schema_and_load(self.vector_store_config["vector_store_name"], docs)
|
||||||
|
|
||||||
@register
|
@register
|
||||||
def similar_search(self, doc, topk):
|
def similar_search(self, doc, topk):
|
||||||
"""vector store similarity_search"""
|
"""vector store similarity_search"""
|
||||||
|
@@ -1,31 +1,35 @@
|
|||||||
|
from typing import List, Optional, Iterable
|
||||||
|
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
from pymilvus import DataType, FieldSchema, CollectionSchema, connections, Collection
|
from pymilvus import DataType, FieldSchema, CollectionSchema, connections, Collection
|
||||||
|
|
||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
|
||||||
from pilot.vector_store.vector_store_base import VectorStoreBase
|
from pilot.vector_store.vector_store_base import VectorStoreBase
|
||||||
|
|
||||||
|
|
||||||
class MilvusStore(VectorStoreBase):
|
class MilvusStore(VectorStoreBase):
|
||||||
def __init__(self, cfg: {}) -> None:
|
def __init__(self, ctx: {}) -> None:
|
||||||
"""Construct a milvus memory storage connection.
|
"""init a milvus storage connection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg (Config): MilvusStore global config.
|
ctx ({}): MilvusStore global config.
|
||||||
"""
|
"""
|
||||||
# self.configure(cfg)
|
# self.configure(cfg)
|
||||||
|
|
||||||
connect_kwargs = {}
|
connect_kwargs = {}
|
||||||
self.uri = None
|
self.uri = None
|
||||||
self.uri = cfg["url"]
|
self.uri = ctx["url"]
|
||||||
self.port = cfg["port"]
|
self.port = ctx["port"]
|
||||||
self.username = cfg.get("username", None)
|
self.username = ctx.get("username", None)
|
||||||
self.password = cfg.get("password", None)
|
self.password = ctx.get("password", None)
|
||||||
self.collection_name = cfg["table_name"]
|
self.collection_name = ctx.get("table_name", None)
|
||||||
self.password = cfg.get("secure", None)
|
self.secure = ctx.get("secure", None)
|
||||||
|
self.model_config = ctx.get("model_config", None)
|
||||||
|
self.embedding = ctx.get("embedding", None)
|
||||||
|
self.fields = []
|
||||||
|
|
||||||
# use HNSW by default.
|
# use HNSW by default.
|
||||||
self.index_params = {
|
self.index_params = {
|
||||||
"metric_type": "IP",
|
"metric_type": "L2",
|
||||||
"index_type": "HNSW",
|
"index_type": "HNSW",
|
||||||
"params": {"M": 8, "efConstruction": 64},
|
"params": {"M": 8, "efConstruction": 64},
|
||||||
}
|
}
|
||||||
@@ -39,20 +43,144 @@ class MilvusStore(VectorStoreBase):
|
|||||||
connect_kwargs["password"] = self.password
|
connect_kwargs["password"] = self.password
|
||||||
|
|
||||||
connections.connect(
|
connections.connect(
|
||||||
**connect_kwargs,
|
|
||||||
host=self.uri or "127.0.0.1",
|
host=self.uri or "127.0.0.1",
|
||||||
port=self.port or "19530",
|
port=self.port or "19530",
|
||||||
alias="default"
|
alias="default"
|
||||||
# secure=self.secure,
|
# secure=self.secure,
|
||||||
)
|
)
|
||||||
|
if self.collection_name is not None:
|
||||||
|
self.col = Collection(self.collection_name)
|
||||||
|
schema = self.col.schema
|
||||||
|
for x in schema.fields:
|
||||||
|
self.fields.append(x.name)
|
||||||
|
if x.auto_id:
|
||||||
|
self.fields.remove(x.name)
|
||||||
|
if x.is_primary:
|
||||||
|
self.primary_field = x.name
|
||||||
|
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
||||||
|
self.vector_field = x.name
|
||||||
|
|
||||||
self.init_schema()
|
|
||||||
|
# self.init_schema()
|
||||||
|
# self.init_collection_schema()
|
||||||
|
|
||||||
|
def init_schema_and_load(self, vector_name, documents):
|
||||||
|
"""Create a Milvus collection, indexes it with HNSW, load document.
|
||||||
|
Args:
|
||||||
|
documents (List[str]): Text to insert.
|
||||||
|
vector_name (Embeddings): your collection name.
|
||||||
|
Returns:
|
||||||
|
VectorStore: The MilvusStore vector store.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from pymilvus import (
|
||||||
|
Collection,
|
||||||
|
CollectionSchema,
|
||||||
|
DataType,
|
||||||
|
FieldSchema,
|
||||||
|
connections,
|
||||||
|
)
|
||||||
|
from pymilvus.orm.types import infer_dtype_bydata
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import pymilvus python package. "
|
||||||
|
"Please install it with `pip install pymilvus`."
|
||||||
|
)
|
||||||
|
# Connect to Milvus instance
|
||||||
|
if not connections.has_connection("default"):
|
||||||
|
connections.connect(
|
||||||
|
host=self.uri or "127.0.0.1",
|
||||||
|
port=self.port or "19530",
|
||||||
|
alias="default"
|
||||||
|
# secure=self.secure,
|
||||||
|
)
|
||||||
|
texts = [d.page_content for d in documents]
|
||||||
|
metadatas = [d.metadata for d in documents]
|
||||||
|
embeddings = self.embedding.embed_query(texts[0])
|
||||||
|
dim = len(embeddings)
|
||||||
|
# Generate unique names
|
||||||
|
primary_field = "pk_id"
|
||||||
|
vector_field = "vector"
|
||||||
|
text_field = "content"
|
||||||
|
self.text_field = text_field
|
||||||
|
collection_name = vector_name
|
||||||
|
fields = []
|
||||||
|
# Determine metadata schema
|
||||||
|
# if metadatas:
|
||||||
|
# # Check if all metadata keys line up
|
||||||
|
# key = metadatas[0].keys()
|
||||||
|
# for x in metadatas:
|
||||||
|
# if key != x.keys():
|
||||||
|
# raise ValueError(
|
||||||
|
# "Mismatched metadata. "
|
||||||
|
# "Make sure all metadata has the same keys and datatype."
|
||||||
|
# )
|
||||||
|
# # Create FieldSchema for each entry in singular metadata.
|
||||||
|
# for key, value in metadatas[0].items():
|
||||||
|
# # Infer the corresponding datatype of the metadata
|
||||||
|
# dtype = infer_dtype_bydata(value)
|
||||||
|
# if dtype == DataType.UNKNOWN:
|
||||||
|
# raise ValueError(f"Unrecognized datatype for {key}.")
|
||||||
|
# elif dtype == DataType.VARCHAR:
|
||||||
|
# # Find out max length text based metadata
|
||||||
|
# max_length = 0
|
||||||
|
# for subvalues in metadatas:
|
||||||
|
# max_length = max(max_length, len(subvalues[key]))
|
||||||
|
# fields.append(
|
||||||
|
# FieldSchema(key, DataType.VARCHAR, max_length=max_length + 1)
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# fields.append(FieldSchema(key, dtype))
|
||||||
|
|
||||||
|
# Find out max length of texts
|
||||||
|
max_length = 0
|
||||||
|
for y in texts:
|
||||||
|
max_length = max(max_length, len(y))
|
||||||
|
# Create the text field
|
||||||
|
fields.append(
|
||||||
|
FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1)
|
||||||
|
)
|
||||||
|
# Create the primary key field
|
||||||
|
fields.append(
|
||||||
|
FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
|
||||||
|
)
|
||||||
|
# Create the vector field
|
||||||
|
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
|
||||||
|
# Create the schema for the collection
|
||||||
|
schema = CollectionSchema(fields)
|
||||||
|
# Create the collection
|
||||||
|
collection = Collection(collection_name, schema)
|
||||||
|
self.col = collection
|
||||||
|
# Index parameters for the collection
|
||||||
|
index = self.index_params
|
||||||
|
# Create the index
|
||||||
|
collection.create_index(vector_field, index)
|
||||||
|
# Create the VectorStore
|
||||||
|
# milvus = cls(
|
||||||
|
# embedding,
|
||||||
|
# kwargs.get("connection_args", {"port": 19530}),
|
||||||
|
# collection_name,
|
||||||
|
# text_field,
|
||||||
|
# )
|
||||||
|
# Add the texts.
|
||||||
|
schema = collection.schema
|
||||||
|
for x in schema.fields:
|
||||||
|
self.fields.append(x.name)
|
||||||
|
if x.auto_id:
|
||||||
|
self.fields.remove(x.name)
|
||||||
|
if x.is_primary:
|
||||||
|
self.primary_field = x.name
|
||||||
|
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
||||||
|
self.vector_field = x.name
|
||||||
|
self._add_texts(texts, metadatas)
|
||||||
|
|
||||||
|
return self.collection_name
|
||||||
|
|
||||||
def init_schema(self) -> None:
|
def init_schema(self) -> None:
|
||||||
"""Initialize collection in milvus database."""
|
"""Initialize collection in milvus database."""
|
||||||
fields = [
|
fields = [
|
||||||
FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
|
FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
|
||||||
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=384),
|
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.model_config["dim"]),
|
||||||
FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
|
FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -75,7 +203,7 @@ class MilvusStore(VectorStoreBase):
|
|||||||
info = self.collection.describe()
|
info = self.collection.describe()
|
||||||
self.collection.load()
|
self.collection.load()
|
||||||
|
|
||||||
def insert(self, text) -> str:
|
def insert(self, text, model_config) -> str:
|
||||||
"""Add an embedding of data into milvus.
|
"""Add an embedding of data into milvus.
|
||||||
Args:
|
Args:
|
||||||
text (str): The raw text to construct embedding index.
|
text (str): The raw text to construct embedding index.
|
||||||
@@ -83,10 +211,54 @@ class MilvusStore(VectorStoreBase):
|
|||||||
str: log.
|
str: log.
|
||||||
"""
|
"""
|
||||||
# embedding = get_ada_embedding(data)
|
# embedding = get_ada_embedding(data)
|
||||||
embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
|
embeddings = HuggingFaceEmbeddings(model_name=self.model_config["model_name"])
|
||||||
result = self.collection.insert([embeddings.embed_documents(text), text])
|
result = self.collection.insert([embeddings.embed_documents(text), text])
|
||||||
_text = (
|
_text = (
|
||||||
"Inserting data into memory at primary key: "
|
"Inserting data into memory at primary key: "
|
||||||
f"{result.primary_keys[0]}:\n data: {text}"
|
f"{result.primary_keys[0]}:\n data: {text}"
|
||||||
)
|
)
|
||||||
return _text
|
return _text
|
||||||
|
|
||||||
|
def _add_texts(
|
||||||
|
self,
|
||||||
|
texts: Iterable[str],
|
||||||
|
metadatas: Optional[List[dict]] = None,
|
||||||
|
partition_name: Optional[str] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
) -> List[str]:
|
||||||
|
"""Insert text data into Milvus.
|
||||||
|
Args:
|
||||||
|
texts (Iterable[str]): The text being embedded and inserted.
|
||||||
|
metadatas (Optional[List[dict]], optional): The metadata that
|
||||||
|
corresponds to each insert. Defaults to None.
|
||||||
|
partition_name (str, optional): The partition of the collection
|
||||||
|
to insert data into. Defaults to None.
|
||||||
|
timeout: specified timeout.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: The resulting keys for each inserted element.
|
||||||
|
"""
|
||||||
|
insert_dict: Any = {self.text_field: list(texts)}
|
||||||
|
try:
|
||||||
|
insert_dict[self.vector_field] = self.embedding.embed_documents(
|
||||||
|
list(texts)
|
||||||
|
)
|
||||||
|
except NotImplementedError:
|
||||||
|
insert_dict[self.vector_field] = [
|
||||||
|
self.embedding.embed_query(x) for x in texts
|
||||||
|
]
|
||||||
|
# Collect the metadata into the insert dict.
|
||||||
|
if len(self.fields) > 2 and metadatas is not None:
|
||||||
|
for d in metadatas:
|
||||||
|
for key, value in d.items():
|
||||||
|
if key in self.fields:
|
||||||
|
insert_dict.setdefault(key, []).append(value)
|
||||||
|
# Convert dict to list of lists for insertion
|
||||||
|
insert_list = [insert_dict[x] for x in self.fields]
|
||||||
|
# Insert into the collection.
|
||||||
|
res = self.col.insert(
|
||||||
|
insert_list, partition_name=partition_name, timeout=timeout
|
||||||
|
)
|
||||||
|
# Flush to make sure newly inserted is immediately searchable.
|
||||||
|
self.col.flush()
|
||||||
|
return res.primary_keys
|
||||||
|
@@ -60,6 +60,7 @@ gTTS==2.3.1
|
|||||||
langchain
|
langchain
|
||||||
nltk
|
nltk
|
||||||
python-dotenv==1.0.0
|
python-dotenv==1.0.0
|
||||||
|
pymilvus
|
||||||
|
|
||||||
# Testing dependencies
|
# Testing dependencies
|
||||||
pytest
|
pytest
|
||||||
|
@@ -2,8 +2,10 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, \
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
KNOWLEDGE_UPLOAD_ROOT_PATH
|
from langchain.vectorstores import Milvus
|
||||||
|
|
||||||
|
from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, VECTOR_STORE_CONFIG
|
||||||
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
||||||
|
|
||||||
|
|
||||||
@@ -12,15 +14,15 @@ class LocalKnowledgeInit:
|
|||||||
model_name = LLM_MODEL_CONFIG["text2vec"]
|
model_name = LLM_MODEL_CONFIG["text2vec"]
|
||||||
top_k: int = VECTOR_SEARCH_TOP_K
|
top_k: int = VECTOR_SEARCH_TOP_K
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, vector_store_config) -> None:
|
||||||
pass
|
self.vector_store_config = vector_store_config
|
||||||
|
|
||||||
def knowledge_persist(self, file_path, vector_name, append_mode):
|
def knowledge_persist(self, file_path, append_mode):
|
||||||
""" knowledge persist """
|
""" knowledge persist """
|
||||||
kv = KnowledgeEmbedding(
|
kv = KnowledgeEmbedding(
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
model_name=LLM_MODEL_CONFIG["text2vec"],
|
model_name=LLM_MODEL_CONFIG["text2vec"],
|
||||||
vector_store_config= {"vector_store_name":vector_name, "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
|
vector_store_config= self.vector_store_config)
|
||||||
vector_store = kv.knowledge_persist_initialization(append_mode)
|
vector_store = kv.knowledge_persist_initialization(append_mode)
|
||||||
return vector_store
|
return vector_store
|
||||||
|
|
||||||
@@ -34,11 +36,15 @@ class LocalKnowledgeInit:
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--vector_name", type=str, default="default")
|
parser.add_argument("--vector_name", type=str, default="keting")
|
||||||
parser.add_argument("--append", type=bool, default=False)
|
parser.add_argument("--append", type=bool, default=False)
|
||||||
|
parser.add_argument("--store_type", type=str, default="Chroma")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
vector_name = args.vector_name
|
vector_name = args.vector_name
|
||||||
append_mode = args.append
|
append_mode = args.append
|
||||||
kv = LocalKnowledgeInit()
|
store_type = args.store_type
|
||||||
vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, vector_name=vector_name, append_mode=append_mode)
|
vector_store_config = {"url": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"], "vector_store_name":vector_name, "vector_store_type":store_type}
|
||||||
|
print(vector_store_config)
|
||||||
|
kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
|
||||||
|
vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
|
||||||
print("your knowledge embedding success...")
|
print("your knowledge embedding success...")
|
Reference in New Issue
Block a user